def get_generator(model, cfg, device, **kwargs):
    ''' Returns the generator object.

    For the 'pointcloud_crop' input type the total volume information and
    (in the sliding-window case) the crop bounds are derived from the config
    and handed to the generator; for every other input type both are None.

    Args:
        model (nn.Module): Occupancy Network model
        cfg (dict): imported yaml config
        device (device): pytorch device
        **kwargs: accepted for interface compatibility; not used here

    Returns:
        generation.Generator3D: the configured generator

    Raises:
        ValueError: if the input type is 'pointcloud_crop' and the encoder
            kwargs contain neither a 'unet' nor a 'unet3d' entry, so the
            network depth cannot be determined.
    '''
    vol_info = None
    vol_bound = None

    if cfg['data']['input_type'] == 'pointcloud_crop':
        data_cfg = cfg['data']
        encoder_kwargs = cfg['model']['encoder_kwargs']

        # calculate the volume boundary
        query_vol_metric = data_cfg['padding'] + 1
        unit_size = data_cfg['unit_size']
        # NOTE: the receptive field is always derived from the 3D U-Net
        # settings, even when the depth below comes from 'unet_kwargs'.
        recep_field = 2**(encoder_kwargs['unet3d_kwargs']['num_levels'] + 2)

        # Key presence (not truthiness) selects where the depth comes from.
        if 'unet' in encoder_kwargs:
            depth = encoder_kwargs['unet_kwargs']['depth']
        elif 'unet3d' in encoder_kwargs:
            depth = encoder_kwargs['unet3d_kwargs']['num_levels']
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `depth`.
            raise ValueError(
                "cfg['model']['encoder_kwargs'] must contain 'unet' or "
                "'unet3d' when the input type is 'pointcloud_crop'")

        vol_info = decide_total_volume_range(query_vol_metric, recep_field,
                                             unit_size, depth)

        grid_reso = data_cfg['query_vol_size'] + recep_field - 1
        grid_reso = update_reso(grid_reso, depth)
        query_vol_size = data_cfg['query_vol_size'] * unit_size
        input_vol_size = grid_reso * unit_size

        # only for the sliding window case
        if cfg['generation']['sliding_window']:
            vol_bound = {
                'query_crop_size': query_vol_size,
                'input_crop_size': input_vol_size,
                'fea_type': encoder_kwargs['plane_type'],
                'reso': grid_reso,
            }

    generator = generation.Generator3D(
        model,
        device=device,
        threshold=cfg['test']['threshold'],
        resolution0=cfg['generation']['resolution_0'],
        upsampling_steps=cfg['generation']['upsampling_steps'],
        sample=cfg['generation']['use_sampling'],
        refinement_step=cfg['generation']['refinement_step'],
        simplify_nfaces=cfg['generation']['simplify_nfaces'],
        input_type=cfg['data']['input_type'],
        padding=cfg['data']['padding'],
        vol_info=vol_info,
        vol_bound=vol_bound,
    )
    return generator
def __init__(self, dataset_folder, fields, split=None, categories=None,
             no_except=True, transform=None, cfg=None):
    ''' Initialization of the 3D shape dataset.

    Collects the list of models per category (from the sub-directories or
    from a '<split>.lst' file) and, for the 'pointcloud_crop' input type,
    precomputes the total volume information from the config.

    Args:
        dataset_folder (str): dataset folder
        fields (dict): dictionary of fields
        split (str): which split is used
        categories (list): list of categories to use
        no_except (bool): no exception
        transform (callable): transformation applied to data points
        cfg (yaml): config file; when None the crop precomputation is
            skipped

    Raises:
        ValueError: if the input type is 'pointcloud_crop' and the encoder
            kwargs contain neither a 'unet' nor a 'unet3d' entry.
    '''
    # Attributes
    self.dataset_folder = dataset_folder
    self.fields = fields
    self.no_except = no_except
    self.transform = transform
    self.cfg = cfg

    # If categories is None, use all subfolders
    if categories is None:
        categories = [
            c for c in os.listdir(dataset_folder)
            if os.path.isdir(os.path.join(dataset_folder, c))
        ]

    # Read metadata file
    metadata_file = os.path.join(dataset_folder, 'metadata.yaml')
    if os.path.exists(metadata_file):
        with open(metadata_file, 'r') as f:
            # yaml.load() without an explicit Loader raises a TypeError on
            # PyYAML >= 6; the metadata is expected to be plain data, so the
            # safe loader is used.
            self.metadata = yaml.safe_load(f)
    else:
        self.metadata = {c: {'id': c, 'name': 'n/a'} for c in categories}

    # Set index
    for c_idx, c in enumerate(categories):
        self.metadata[c]['idx'] = c_idx

    # Get all models
    self.models = []
    for c in categories:
        subpath = os.path.join(dataset_folder, c)
        if not os.path.isdir(subpath):
            logger.warning('Category %s does not exist in dataset.', c)
            # Nothing to list or read for a missing category; without this
            # the code below would fail on the non-existent directory.
            continue

        if split is None:
            models_c = [
                d for d in os.listdir(subpath)
                if os.path.isdir(os.path.join(subpath, d))
            ]
        else:
            split_file = os.path.join(subpath, split + '.lst')
            with open(split_file, 'r') as f:
                # Drop every empty line, not just the first one.
                models_c = [m for m in f.read().split('\n') if m]

        self.models += [{'category': c, 'model': m} for m in models_c]

    # precompute (cfg defaults to None, so guard before subscripting)
    if cfg is not None and cfg['data']['input_type'] == 'pointcloud_crop':
        self.split = split
        encoder_kwargs = cfg['model']['encoder_kwargs']

        # proper resolution for feature plane/volume of the ENTIRE scene
        query_vol_metric = cfg['data']['padding'] + 1
        unit_size = cfg['data']['unit_size']
        recep_field = 2**(encoder_kwargs['unet3d_kwargs']['num_levels'] + 2)

        # Key presence (not truthiness) selects where the depth comes from.
        if 'unet' in encoder_kwargs:
            depth = encoder_kwargs['unet_kwargs']['depth']
        elif 'unet3d' in encoder_kwargs:
            depth = encoder_kwargs['unet3d_kwargs']['num_levels']
        else:
            # Previously this failed with an UnboundLocalError on `depth`.
            raise ValueError(
                "cfg['model']['encoder_kwargs'] must contain 'unet' or "
                "'unet3d' when the input type is 'pointcloud_crop'")
        self.depth = depth

        # for the sliding-window case, pass all points: a very large metric
        # volume is used so that it contains the whole scene
        if cfg['generation']['sliding_window']:
            volume_metric = 100000
        else:
            volume_metric = query_vol_metric

        self.total_input_vol, self.total_query_vol, self.total_reso = \
            decide_total_volume_range(volume_metric, recep_field,
                                      unit_size, depth)