def predict_crop_occ(self, pi, c, vol_bound=None, **kwargs):
    ''' Predicts the occupancy logits for the query points of one crop.

    Args:
        pi (tensor): query points; the first dimension indexes the points
        c: encoded feature volumes, handed to the decoder unchanged
        vol_bound (dict): volume boundary of the crop; its 'input_vol'
            entry is passed to normalize_coord
        **kwargs: extra keyword arguments forwarded to self.model.decode

    Returns:
        tensor: the decoder logits with the leading batch dimension
            squeezed away. For an empty pi an (uninitialized) tensor of
            length 0 is returned and the decoder is not called.
    '''
    n_points = pi.shape[0]
    if n_points == 0:
        # nothing to decode for an empty crop
        return pi.new_empty(n_points)

    # The decoder input holds the raw points under 'p' and, under 'p_n',
    # one normalized copy per feature type listed in self.vol_bound
    # (projected coordinates normalized to the range of [0, 1]).
    query = {
        'p': pi.unsqueeze(0),
        'p_n': {
            fea_key: normalize_coord(
                pi.clone(), vol_bound['input_vol'], plane=fea_key
            ).unsqueeze(0).to(self.device)
            for fea_key in self.vol_bound['fea_type']
        },
    }

    # predict occupancy of the current crop without tracking gradients
    with torch.no_grad():
        logits = self.model.decode(query, c, **kwargs).logits

    return logits.squeeze(0)
def eval_points(self, p, c=None, vol_bound=None, **kwargs):
    ''' Evaluates the occupancy logits for the points, chunk by chunk.

    The points are split into chunks of self.points_batch_size and each
    chunk is decoded separately; the per-chunk results are concatenated
    along the first dimension.

    Args:
        p (tensor): points
        c: encoded feature volumes; in the entire-scene crop mode its
            keys select the planes that are normalized
        vol_bound (dict): volume boundary, only forwarded to
            predict_crop_occ in the sliding-window mode
        **kwargs: extra keyword arguments forwarded to the decoder

    Returns:
        tensor: concatenated logits. The sliding-window branch appends
            the result of predict_crop_occ as-is, while the other two
            branches detach each chunk and move it to the CPU.
    '''
    chunk_logits = []
    for chunk in torch.split(p, self.points_batch_size):
        if self.input_type != 'pointcloud_crop':
            # plain mode: the decoder receives the points directly
            batch = chunk.unsqueeze(0).to(self.device)
            with torch.no_grad():
                logits = self.model.decode(batch, c, **kwargs).logits
            chunk_logits.append(logits.squeeze(0).detach().cpu())
            continue

        if self.vol_bound is not None:
            # sliding-window manner: delegate to the per-crop predictor
            chunk_logits.append(
                self.predict_crop_occ(chunk, c, vol_bound=vol_bound, **kwargs))
            continue

        # entire scene: the decoder receives the points plus one copy per
        # plane of c, normalized to the range of [0, 1]
        query = {'p': chunk.unsqueeze(0).to(self.device)}
        query['p_n'] = {
            plane: normalize_coord(
                chunk.clone(), self.input_vol, plane=plane
            ).unsqueeze(0).to(self.device)
            for plane in c.keys()
        }
        with torch.no_grad():
            logits = self.model.decode(query, c, **kwargs).logits
        chunk_logits.append(logits.squeeze(0).detach().cpu())

    return torch.cat(chunk_logits, dim=0)
def load(self, model_path, idx, vol):
    ''' Loads the data point and crops it to the query volume.

    Args:
        model_path (str): path to model
        idx (int): ID of data point (not used by this method)
        vol (dict): precomputed volume info; 'query_vol' holds the lower
            and upper corner of the crop, 'plane_type' the planes to
            normalize for and 'input_vol' is passed to normalize_coord

    Returns:
        dict: the cropped points under the key None, their occupancies
            (float32) under 'occ' and, under 'normalized', one normalized
            copy of the points per plane in vol['plane_type'].
    '''
    if self.multi_files is None:
        file_path = os.path.join(model_path, self.file_name)
    else:
        # pick one of the pre-split files at random
        num = np.random.randint(self.multi_files)
        file_path = os.path.join(model_path, self.file_name,
                                 '%s_%02d.npz' % (self.file_name, num))

    points_dict = np.load(file_path)
    try:
        points = points_dict['points']
        occupancies = points_dict['occupancies']
    finally:
        # An .npz archive stays open until it is closed explicitly; close
        # it here so repeated loads do not leak file descriptors. Other
        # return types of np.load have no close() and are left alone.
        if hasattr(points_dict, 'close'):
            points_dict.close()

    # Break symmetry if given in float16:
    if points.dtype == np.float16:
        points = points.astype(np.float32)
        points += 1e-4 * np.random.randn(*points.shape)

    if self.unpackbits:
        # drop the padding bits added when the occupancies were packed
        occupancies = np.unpackbits(occupancies)[:points.shape[0]]
    occupancies = occupancies.astype(np.float32)

    # acquire the crop: keep points inside the query volume on all three
    # axes (both bounds inclusive)
    lower, upper = vol['query_vol'][0], vol['query_vol'][1]
    inside = [
        (points[:, axis] >= lower[axis]) & (points[:, axis] <= upper[axis])
        for axis in range(3)
    ]
    ind = inside[0] & inside[1] & inside[2]

    data = {
        None: points[ind],
        'occ': occupancies[ind],
    }

    if self.transform is not None:
        data = self.transform(data)

    # calculate normalized coordinate w.r.t. defined query volume
    # (projected coordinates normalized to the range of [0, 1])
    data['normalized'] = {
        key: normalize_coord(data[None].copy(), vol['input_vol'], plane=key)
        for key in vol['plane_type']
    }

    return data