def encode_crop(self, inputs, device, vol_bound=None):
    ''' Encode a crop to feature volumes.

    Selects the points of ``inputs`` that fall inside
    ``vol_bound['input_vol']``, computes a cell index for them with
    ``coord2index`` for every feature type in ``self.vol_bound['fea_type']``,
    and runs ``self.model.encode_inputs`` on the result without gradient
    tracking.

    Args:
        inputs (torch.Tensor): input point cloud, indexed as
            ``inputs[:, :, axis]`` (rank 3). Presumably (batch, n_points, 3)
            with batch size 1, since ``inputs[mask]`` flattens the batch
            dimension -- TODO confirm.
        device (device): pytorch device passed to ``add_key``
        vol_bound (dict): volume boundary; defaults to ``self.vol_bound``.
            Only its ``'input_vol'`` entry is read here -- ``'fea_type'`` and
            ``'reso'`` always come from ``self.vol_bound``.

    Returns:
        Whatever ``self.model.encode_inputs`` returns for the crop.
    '''
    if vol_bound is None:
        vol_bound = self.vol_bound
    reso = self.vol_bound['reso']
    lower, upper = vol_bound['input_vol'][0], vol_bound['input_vol'][1]

    # The crop is independent of the feature type, so compute it once.
    # Half-open interval [lower, upper) on each axis.
    # NOTE(review): PatchPointCloudField-style loaders use an inclusive upper
    # bound (<=); confirm which convention is intended.
    mask = ((inputs[:, :, 0] >= lower[0]) & (inputs[:, :, 0] < upper[0]) &
            (inputs[:, :, 1] >= lower[1]) & (inputs[:, :, 1] < upper[1]) &
            (inputs[:, :, 2] >= lower[2]) & (inputs[:, :, 2] < upper[2]))
    p_input = inputs[mask]

    empty_crop = p_input.shape[0] == 0
    if empty_crop:
        # No points in the current crop: feed all points instead, and below
        # route every one of them to the sentinel index. reshape (rather than
        # squeeze) keeps a single-point cloud 2-D.
        p_input = inputs.reshape(-1, inputs.shape[-1])

    index = {}
    for fea in self.vol_bound['fea_type']:
        ind = coord2index(p_input.clone(), vol_bound['input_vol'],
                          reso=reso, plane=fea)
        if empty_crop:
            # reso**3 (grid) / reso**2 (plane) is presumably the
            # out-of-range "no cell" sentinel -- confirm against coord2index.
            ind[~mask] = reso ** 3 if fea == 'grid' else reso ** 2
        index[fea] = ind.unsqueeze(0)

    # Built once, after every feature type has its index.
    input_cur = add_key(p_input.unsqueeze(0), index, 'points', 'index',
                        device=device)

    with torch.no_grad():
        c = self.model.encode_inputs(input_cur)
    return c
def load(self, model_path, idx, vol):
    ''' Loads the data point.

    Reads ``points`` and ``normals`` from an ``.npz`` file, optionally applies
    ``self.transform``, zeroes the points that lie outside
    ``vol['input_vol']``, and computes a cell index for every point per entry
    of ``vol['plane_type']``.

    Args:
        model_path (str): path to model
        idx (int): ID of data point (not used by this method)
        vol (dict): precomputed volume info; ``'input_vol'``, ``'plane_type'``
            and ``'reso'`` are read

    Returns:
        dict: ``None`` -> points (float32; out-of-volume rows set to 0),
        ``'normals'``, ``'mask'`` (True for points OUTSIDE the input volume),
        ``'ind'`` (dict mapping each plane type to its index array), plus any
        keys added by ``self.transform``.
    '''
    if self.multi_files is None:
        file_path = os.path.join(model_path, self.file_name)
    else:
        # pick one of ``multi_files`` pre-generated variants at random
        num = np.random.randint(self.multi_files)
        file_path = os.path.join(model_path, self.file_name,
                                 '%s_%02d.npz' % (self.file_name, num))

    # np.load on an .npz returns an NpzFile that holds the archive open;
    # the context manager closes it once the arrays are copied out.
    with np.load(file_path) as pointcloud_dict:
        points = pointcloud_dict['points'].astype(np.float32)
        normals = pointcloud_dict['normals'].astype(np.float32)

    # Build ``data`` unconditionally so the transform-free path also works.
    data = {None: points, 'normals': normals}
    # add noise globally
    if self.transform is not None:
        data = self.transform(data)
    points = data[None]

    # acquire the crop index: closed interval [lower, upper] on each axis
    lower, upper = vol['input_vol'][0], vol['input_vol'][1]
    inside = np.ones(points.shape[0], dtype=bool)
    for axis in range(3):
        inside &= (points[:, axis] >= lower[axis]) & (points[:, axis] <= upper[axis])
    mask = ~inside  # True means outside the boundary!!
    data['mask'] = mask
    # ``points`` is the same array as data[None], so this edits the output
    points[mask] = 0.0

    # calculate index of each point w.r.t. defined resolution
    reso = vol['reso']
    index = {}
    for key in vol['plane_type']:
        ind = coord2index(points.copy(), vol['input_vol'], reso=reso, plane=key)
        # reso**3 (grid) / reso**2 (plane) is presumably the out-of-range
        # sentinel for points outside the volume -- confirm vs. coord2index.
        ind[:, mask] = reso ** 3 if key == 'grid' else reso ** 2
        index[key] = ind
    data['ind'] = index

    return data