def get_normalization_params_dataset(cat_id):
    """Return the saved normalization-parameter dataset for one or more categories.

    Args:
        cat_id: a single category id, or a list/tuple of category ids.

    Returns:
        The dataset saved by ``_NormalizationParamsAutoSavingManager``, with
        every value converted to a dict mapping each key to a float32 numpy
        array. For a list/tuple input, the per-category datasets are put in a
        dict keyed by category id and wrapped in a ``BiKeyDataset``.
    """
    from dids.core import BiKeyDataset

    def _load(cat):
        # Per-category saved dataset from the auto-saving manager.
        return _NormalizationParamsAutoSavingManager(cat).get_saved_dataset()

    def _to_float32(params):
        # Cast each stored entry to a float32 array.
        return {key: np.array(value, dtype=np.float32)
                for key, value in params.items()}

    if not isinstance(cat_id, (list, tuple)):
        return _load(cat_id).map(_to_float32)
    per_category = {cat: _load(cat) for cat in cat_id}
    return BiKeyDataset(per_category).map(_to_float32)
def get_ffd_dataset(cat_ids, n=3, edge_length_threshold=None, n_samples=None):
    """Build a ``BiKeyDataset`` of FFD datasets keyed by category id.

    Args:
        cat_ids: a single category id string, or an iterable of category ids.
        n: forwarded to ``_get_ffd_dataset`` as ``n``.
        edge_length_threshold: forwarded to ``_get_ffd_dataset``.
        n_samples: forwarded to ``_get_ffd_dataset``.

    Returns:
        A ``BiKeyDataset`` built from a dict mapping each category id to its
        ``_get_ffd_dataset`` result. A single string id is treated as a
        one-element list, so the return type is always ``BiKeyDataset``.
    """
    from dids.core import BiKeyDataset
    if isinstance(cat_ids, str):
        cat_ids = [cat_ids]
    datasets = {}
    for cat in cat_ids:
        datasets[cat] = _get_ffd_dataset(
            cat,
            n=n,
            edge_length_threshold=edge_length_threshold,
            n_samples=n_samples,
        )
    return BiKeyDataset(datasets)
def get_annotations_ffd_dataset(cat_id, n=3):
    """Return the annotations FFD dataset for one or more categories.

    Args:
        cat_id: a single category id, or a list/tuple of category ids.
        n: forwarded to ``_get_annotations_ffd_dataset``.

    Returns:
        For a single id, the ``_get_annotations_ffd_dataset`` result directly.
        For a list/tuple, a ``BiKeyDataset`` built from a dict keyed by
        category id.
    """
    if not isinstance(cat_id, (list, tuple)):
        return _get_annotations_ffd_dataset(cat_id, n=n)
    from dids.core import BiKeyDataset
    return BiKeyDataset(
        {cat: _get_annotations_ffd_dataset(cat, n=n) for cat in cat_id})
def get_template_mesh_dataset(cat_id, edge_length_threshold=None):
    """Return the template mesh dataset for one or more categories.

    Args:
        cat_id: a single category id, or a list/tuple of category ids.
        edge_length_threshold: forwarded positionally to
            ``_get_template_mesh_dataset``.

    Returns:
        For a single id, the ``_get_template_mesh_dataset`` result directly.
        For a list/tuple, a ``BiKeyDataset`` built from a dict keyed by
        category id.
    """
    if not isinstance(cat_id, (list, tuple)):
        return _get_template_mesh_dataset(cat_id, edge_length_threshold)
    from dids.core import BiKeyDataset
    per_category = {
        cat: _get_template_mesh_dataset(cat, edge_length_threshold)
        for cat in cat_id
    }
    return BiKeyDataset(per_category)
def get_point_cloud_dataset(cat_id, n_samples, example_ids=None, mode='r'):
    """Return the point-cloud dataset for one or more categories.

    Args:
        cat_id: a single category id, or a list/tuple of category ids.
        n_samples: forwarded to ``_get_point_cloud_dataset``.
        example_ids: for a single category, forwarded as-is. For a list/tuple
            of categories, one entry per category (zipped with ``cat_id``);
            ``None`` means ``None`` is passed for every category.
        mode: forwarded to ``_get_point_cloud_dataset``.

    Returns:
        For a single id, the ``_get_point_cloud_dataset`` result directly.
        For a list/tuple, a ``BiKeyDataset`` built from a dict keyed by
        category id.
    """
    if not isinstance(cat_id, (tuple, list)):
        return _get_point_cloud_dataset(cat_id, n_samples, example_ids, mode)
    # Imported locally, as the sibling helpers do, so this branch does not
    # depend on a module-level import of BiKeyDataset being present.
    from dids.core import BiKeyDataset
    if example_ids is None:
        example_ids = tuple(None for _ in cat_id)
    datasets = {
        c: _get_point_cloud_dataset(c, n_samples, e, mode)
        for c, e in zip(cat_id, example_ids)
    }
    return BiKeyDataset(datasets)
def get_cloud_normal_dataset(cat_id, n_samples, example_ids=None, mode='r'):
    """Build a ``BiKeyDataset`` of cloud-normal datasets keyed by category id.

    Args:
        cat_id: a single category id, or a list/tuple of category ids.
            A single id is wrapped into a one-element list, so the return
            type is always ``BiKeyDataset``.
        n_samples: forwarded to ``_get_cloud_normal_dataset``.
        example_ids: for a single category, forwarded as-is. For a list/tuple
            of categories, one entry per category (zipped with ``cat_id``);
            ``None`` means ``None`` is passed for every category.
        mode: forwarded to ``_get_cloud_normal_dataset``.

    Returns:
        A ``BiKeyDataset`` built from a dict keyed by category id.
    """
    # Imported locally, as the sibling helpers do, so this function does not
    # depend on a module-level import of BiKeyDataset being present.
    from dids.core import BiKeyDataset
    if not isinstance(cat_id, (tuple, list)):
        cat_id = [cat_id]
        example_ids = [example_ids]
    elif example_ids is None:
        example_ids = [None for _ in cat_id]
    datasets = {
        c: _get_cloud_normal_dataset(c, n_samples, e, mode)
        for c, e in zip(cat_id, example_ids)
    }
    return BiKeyDataset(datasets)
def get_image_dataset(cat_ids, example_ids, view_indices, render_config=None):
    """Return a dataset of rendered images for one or more categories.

    Args:
        cat_ids: a single category id string, or a list/tuple of category ids.
        example_ids: for a single (string) category, the value forwarded as
            ``example_ids`` to that category's multi-view dataset. For several
            categories, one entry per category (zipped with ``cat_ids``);
            ``None`` means ``None`` is forwarded for every category, exactly
            as in the single-category case.
        view_indices: a single int view index, or a sequence of view indices.
        render_config: object providing ``get_multi_view_dataset``; a default
            ``RenderConfig()`` is created when ``None``.

    Returns:
        A ``BiKeyDataset`` (keyed by category id) whose images are passed
        through ``with_background(image, 255)`` and whose keys are remapped
        with ``map_keys`` between nested ``(cat, (a, b))`` and flat
        ``(cat, a, b)`` forms. Presumably the exposed keys are the flat
        form; confirm against dids ``map_keys`` semantics.
    """
    from shapenet.image import with_background
    from dids.core import BiKeyDataset
    if render_config is None:
        from shapenet.core.blender_renderings.config import RenderConfig
        render_config = RenderConfig()
    if isinstance(cat_ids, str):
        cat_ids = [cat_ids]
        example_ids = [example_ids]
    elif example_ids is None:
        # zip(cat_ids, None) would raise TypeError; forward None per category,
        # the same value the single-category path already passes through.
        example_ids = [None] * len(cat_ids)
    if isinstance(view_indices, int):
        view_indices = [view_indices]
    datasets = {
        c: render_config.get_multi_view_dataset(
            c, view_indices=view_indices, example_ids=eid)
        for c, eid in zip(cat_ids, example_ids)
    }
    dataset = BiKeyDataset(datasets).map(
        lambda image: with_background(image, 255))
    dataset = dataset.map_keys(
        lambda key: (key[0], (key[1], key[2])),
        lambda key: (key[0],) + key[1])
    return dataset
def get_cloud_dataset(cat_ids, example_ids, n_samples=16384, n_resamples=1024):
    """Return a resampled point-cloud dataset for one or more categories.

    For each category, a ``PointCloudAutoSavingManager(cat_id, n_samples)`` is
    created. If its ``path`` is not an existing file, ``save_all()`` is called
    first, which has a side effect on disk. The saved dataset is then opened
    with ``mode='r'`` and restricted with ``subset`` to that category's
    example ids.

    Args:
        cat_ids: a single category id string, or a list/tuple of category ids.
        example_ids: for a single (string) category, the value passed to
            ``subset``. For several categories, one entry per category (zipped
            with ``cat_ids``); ``None`` means ``None`` is passed to ``subset``
            for every category, exactly as in the single-category case.
        n_samples: forwarded to ``PointCloudAutoSavingManager``.
        n_resamples: forwarded to ``sample_points`` for each cloud.

    Returns:
        A ``BiKeyDataset`` keyed by category id whose values are
        ``sample_points(np.array(x, dtype=np.float32), n_resamples)``.
    """
    import os
    from shapenet.core.point_clouds import PointCloudAutoSavingManager
    from util3d.point_cloud import sample_points
    from dids.core import BiKeyDataset
    if isinstance(cat_ids, str):
        cat_ids = [cat_ids]
        example_ids = [example_ids]
    elif example_ids is None:
        # zip(cat_ids, None) would raise TypeError; pass None per category,
        # the same value the single-category path already forwards.
        example_ids = [None] * len(cat_ids)
    datasets = {}
    for cat_id, e_ids in zip(cat_ids, example_ids):
        manager = PointCloudAutoSavingManager(cat_id, n_samples)
        if not os.path.isfile(manager.path):
            manager.save_all()
        datasets[cat_id] = manager.get_saving_dataset(
            mode='r').subset(e_ids)
    return BiKeyDataset(datasets).map(
        lambda x: sample_points(np.array(x, dtype=np.float32), n_resamples))