def get_generator(model, cfg, device, **kwargs):
    ''' Returns the generator object.

    For the ``'pointcloud_crop'`` input type the crop geometry is derived
    from the config and handed to the generator as ``vol_info`` (always) and
    ``vol_bound`` (only when ``cfg['generation']['sliding_window']`` is
    truthy). For every other input type both are ``None``.

    Args:
        model (nn.Module): Occupancy Network model
        cfg (dict): imported yaml config
        device (device): pytorch device
        **kwargs: accepted for interface compatibility; not used here

    Returns:
        generation.Generator3D: generator built from ``cfg``

    Raises:
        ValueError: if the input type is ``'pointcloud_crop'`` and
            ``cfg['model']['encoder_kwargs']`` has neither a ``'unet'`` nor a
            ``'unet3d'`` key, so the U-Net depth cannot be determined.
    '''
    vol_info = None
    vol_bound = None

    if cfg['data']['input_type'] == 'pointcloud_crop':
        encoder_kwargs = cfg['model']['encoder_kwargs']

        # calculate the volume boundary
        query_vol_metric = cfg['data']['padding'] + 1
        unit_size = cfg['data']['unit_size']
        recep_field = 2**(encoder_kwargs['unet3d_kwargs']['num_levels'] + 2)

        # The depth comes from whichever U-Net variant is configured; only
        # the presence of the key is checked, not its value.
        if 'unet' in encoder_kwargs:
            depth = encoder_kwargs['unet_kwargs']['depth']
        elif 'unet3d' in encoder_kwargs:
            depth = encoder_kwargs['unet3d_kwargs']['num_levels']
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on `depth`.
            raise ValueError(
                "cfg['model']['encoder_kwargs'] must contain 'unet' or "
                "'unet3d' to determine the U-Net depth for input_type "
                "'pointcloud_crop'")

        vol_info = decide_total_volume_range(query_vol_metric, recep_field,
                                             unit_size, depth)

        grid_reso = cfg['data']['query_vol_size'] + recep_field - 1
        grid_reso = update_reso(grid_reso, depth)
        query_vol_size = cfg['data']['query_vol_size'] * unit_size
        input_vol_size = grid_reso * unit_size

        # only for the sliding window case
        if cfg['generation']['sliding_window']:
            vol_bound = {
                'query_crop_size': query_vol_size,
                'input_crop_size': input_vol_size,
                'fea_type': encoder_kwargs['plane_type'],
                'reso': grid_reso,
            }

    generator = generation.Generator3D(
        model,
        device=device,
        threshold=cfg['test']['threshold'],
        resolution0=cfg['generation']['resolution_0'],
        upsampling_steps=cfg['generation']['upsampling_steps'],
        sample=cfg['generation']['use_sampling'],
        refinement_step=cfg['generation']['refinement_step'],
        simplify_nfaces=cfg['generation']['simplify_nfaces'],
        input_type=cfg['data']['input_type'],
        padding=cfg['data']['padding'],
        vol_info=vol_info,
        vol_bound=vol_bound,
    )
    return generator
def get_vol_info(self, model_path):
    ''' Get crop information

    Loads the ``'points'`` array of the sample's point cloud file and builds
    the description of the input/query volumes. For the ``'train'`` split a
    crop centre is sampled uniformly inside the per-axis min/max of the
    first three point coordinates and the volumes are centred on it; for any
    other split the precomputed ``self.total_*`` values are returned.

    Args:
        model_path (str): path to the current data

    Returns:
        dict: with keys ``'plane_type'``, ``'reso'``, ``'input_vol'`` and
            ``'query_vol'``. On the train split the two volumes are
            ``[lower_bound, upper_bound]`` lists of float32 arrays.
    '''
    data_cfg = self.cfg['data']
    encoder_kwargs = self.cfg['model']['encoder_kwargs']

    query_vol_size = data_cfg['query_vol_size']
    unit_size = data_cfg['unit_size']
    field_name = data_cfg['pointcloud_file']
    plane_type = encoder_kwargs['plane_type']
    recep_field = 2**(encoder_kwargs['unet3d_kwargs']['num_levels'] + 2)

    if data_cfg['multi_files'] is None:
        file_path = os.path.join(model_path, field_name)
    else:
        # pick one of the numbered sub-files at random
        num = np.random.randint(data_cfg['multi_files'])
        file_path = os.path.join(model_path, field_name,
                                 '%s_%02d.npz' % (field_name, num))

    # Close the archive as soon as the array has been read; leaving it open
    # leaks one file handle per call.
    # NOTE(review): the points are only used on the train split, but the
    # load is kept unconditional so RNG consumption and missing-file errors
    # stay as they were.
    with np.load(file_path) as points_dict:
        p = points_dict['points']

    if self.split == 'train':
        # randomly sample a point as the center of input/query volume
        p_c = [
            np.random.uniform(p[:, i].min(), p[:, i].max()) for i in range(3)
        ]
        p_c = np.array(p_c).astype(np.float32)

        reso = query_vol_size + recep_field - 1
        # make sure the defined reso can be properly processed by UNet
        reso = update_reso(reso, self.depth)
        input_vol_metric = reso * unit_size
        query_vol_metric = query_vol_size * unit_size

        # bound for the volumes
        input_vol = [p_c - input_vol_metric / 2, p_c + input_vol_metric / 2]
        query_vol = [p_c - query_vol_metric / 2, p_c + query_vol_metric / 2]
    else:
        reso = self.total_reso
        input_vol = self.total_input_vol
        query_vol = self.total_query_vol

    vol_info = {
        'plane_type': plane_type,
        'reso': reso,
        'input_vol': input_vol,
        'query_vol': query_vol,
    }
    return vol_info
def get_model(cfg, device=None, dataset=None, **kwargs):
    ''' Return the Occupancy Network model.

    Note that ``cfg['model']['encoder_kwargs']`` and
    ``cfg['model']['decoder_kwargs']`` are updated in place with the
    derived settings (``unit_size``, ``local_coord``, ``pos_encoding`` and,
    for ``'pointcloud_crop'`` inputs, ``grid_resolution`` /
    ``plane_resolution``) before being passed to the constructors.

    Args:
        cfg (dict): imported yaml config
        device (device): pytorch device
        dataset (dataset): dataset; its ``split``, ``depth`` and
            ``total_reso`` attributes are read for ``'pointcloud_crop'``
            inputs, and its length when the encoder is ``'idx'``
        **kwargs: accepted for interface compatibility; not used here

    Returns:
        models.ConvolutionalOccupancyNetwork: model built from the
            configured decoder and (optional) encoder
    '''
    model_cfg = cfg['model']
    data_cfg = cfg['data']

    decoder = model_cfg['decoder']
    encoder = model_cfg['encoder']
    dim = data_cfg['dim']
    c_dim = model_cfg['c_dim']
    decoder_kwargs = model_cfg['decoder_kwargs']
    encoder_kwargs = model_cfg['encoder_kwargs']
    padding = data_cfg['padding']

    # for pointcloud_crop: best-effort propagation of the unit size
    try:
        encoder_kwargs['unit_size'] = data_cfg['unit_size']
        decoder_kwargs['unit_size'] = data_cfg['unit_size']
    except (KeyError, TypeError):
        # KeyError: no 'unit_size' configured. TypeError: a kwargs entry
        # does not support item assignment (presumably None when left empty
        # in the yaml -- verify against the configs). Anything else, e.g.
        # KeyboardInterrupt, is no longer swallowed.
        pass

    # local positional encoding
    if 'local_coord' in model_cfg:
        encoder_kwargs['local_coord'] = model_cfg['local_coord']
        decoder_kwargs['local_coord'] = model_cfg['local_coord']
    if 'pos_encoding' in model_cfg:
        encoder_kwargs['pos_encoding'] = model_cfg['pos_encoding']
        decoder_kwargs['pos_encoding'] = model_cfg['pos_encoding']

    # update the feature volume/plane resolution
    if data_cfg['input_type'] == 'pointcloud_crop':
        fea_type = encoder_kwargs['plane_type']
        if (dataset.split == 'train') or cfg['generation']['sliding_window']:
            recep_field = 2**(
                encoder_kwargs['unet3d_kwargs']['num_levels'] + 2)
            reso = data_cfg['query_vol_size'] + recep_field - 1
            if 'grid' in fea_type:
                encoder_kwargs['grid_resolution'] = update_reso(
                    reso, dataset.depth)
            if set(fea_type) & {'xz', 'xy', 'yz'}:
                encoder_kwargs['plane_resolution'] = update_reso(
                    reso, dataset.depth)
        # TODO: run validation in room level during training
        else:
            if 'grid' in fea_type:
                encoder_kwargs['grid_resolution'] = dataset.total_reso
            if set(fea_type) & {'xz', 'xy', 'yz'}:
                encoder_kwargs['plane_resolution'] = dataset.total_reso

    decoder = models.decoder_dict[decoder](dim=dim,
                                           c_dim=c_dim,
                                           padding=padding,
                                           **decoder_kwargs)

    # 'idx' selects a per-sample embedding; None means no encoder at all
    if encoder == 'idx':
        encoder = nn.Embedding(len(dataset), c_dim)
    elif encoder is not None:
        encoder = encoder_dict[encoder](dim=dim,
                                        c_dim=c_dim,
                                        padding=padding,
                                        **encoder_kwargs)

    model = models.ConvolutionalOccupancyNetwork(decoder,
                                                 encoder,
                                                 device=device)
    return model