def compute_overlap_and_matches(path1, path2, max_distance_overlap, reciprocity=False, num_pos=1):
    """Load two saved point clouds from disk and compute their matches and overlap ratio.

    Args:
        path1: file of the source cloud, read with ``torch.load``; the loaded
            object must expose a ``pos`` attribute.
        path2: file of the target cloud, read with ``torch.load``.
        max_distance_overlap: search radius handed to ``ball_query``.
        reciprocity: when True, also search in the opposite direction and
            append a second overlap ratio.
        num_pos: ``max_num`` handed to ``ball_query``.

    Returns:
        dict with keys ``pair`` (matches from ``filter_pair``), ``path_source``
        (``path1``), ``path_target`` (``path2``) and ``overlap`` (a list holding
        one ratio, or two when ``reciprocity`` is set). Each ratio is the number
        of filtered pairs divided by the size of the corresponding cloud.
    """
    # NOTE(review): torch.load unpickles the file; only call this on trusted data.
    source = torch.load(path1)
    target = torch.load(path2)

    def _matches(query, support):
        # Radius search followed by the project's pair filtering.
        found, found_dist = ball_query(
            query, support, radius=max_distance_overlap, max_num=num_pos, mode=1
        )
        return filter_pair(found, found_dist)

    # NOTE(review): `sorted` is not passed to ball_query here; confirm whether
    # the `num_pos` neighbours kept are meant to be the nearest ones.
    matches = _matches(target.pos, source.pos)
    ratios = [matches.shape[0] / len(source.pos)]
    if reciprocity:
        reverse_matches = _matches(source.pos, target.pos)
        ratios.append(reverse_matches.shape[0] / len(target.pos))
    return dict(pair=matches, path_source=path1, path_target=path2, overlap=ratios)
def compute_overlap_and_matches(data1, data2, max_distance_overlap, reciprocity=False, num_pos=1, rot_gt=None):
    """Compute matches and overlap between two point clouds, rotating the first one.

    Args:
        data1: source cloud exposing ``pos``.
        data2: target cloud exposing ``pos``.
        max_distance_overlap: search radius handed to ``ball_query``.
        reciprocity: when True, also search in the opposite direction and
            append a second overlap ratio.
        num_pos: ``max_num`` handed to ``ball_query``.
        rot_gt: 3x3 rotation applied to ``data1.pos`` as ``pos @ rot_gt.T``
            before searching. ``None`` (the default) means identity. A tensor
            or array-like of any float dtype is accepted; it is converted to the
            dtype/device of the float-cast positions.

    Returns:
        dict with keys ``pair`` (matches from ``filter_pair``), ``pair2``
        (reverse matches when ``reciprocity`` is set, otherwise an empty list)
        and ``overlap`` (a list with one ratio, or two with ``reciprocity``).
    """
    # A None default avoids a tensor built once at definition time and shared
    # by every call.
    if rot_gt is None:
        rot_gt = torch.eye(3)
    pos1 = data1.pos.to(torch.float)
    pos2 = data2.pos.to(torch.float)
    # Convert the rotation to match pos1; a float64 (or numpy) rotation would
    # otherwise make the float32 matmul fail. Rotate once, reuse in both queries.
    rot = torch.as_tensor(rot_gt, dtype=pos1.dtype, device=pos1.device)
    pos1_rot = pos1 @ rot.T

    pair, dist = ball_query(pos2, pos1_rot, radius=max_distance_overlap, max_num=num_pos, mode=1, sorted=True)
    pair = filter_pair(pair, dist)
    pair2 = []
    overlap = [pair.shape[0] / len(data1.pos)]
    if reciprocity:
        pair2, dist2 = ball_query(pos1_rot, pos2, radius=max_distance_overlap, max_num=num_pos, mode=1, sorted=True)
        pair2 = filter_pair(pair2, dist2)
        overlap.append(pair2.shape[0] / len(data2.pos))
    return dict(pair=pair, pair2=pair2, overlap=overlap)
def __call__(self, data):
    """Keep only points with more than ``self.min_num`` neighbours within ``self.radius_nn``.

    Neighbours are counted per row of the mode-0 ``ball_query`` distance
    output, counting only strictly positive entries. Presumably this excludes
    the point itself and any padding; TODO confirm against ball_query's
    padding value. The filtered data is built by ``apply_mask`` with
    ``self.skip_keys``.
    """
    _, neighbour_dist = ball_query(data.pos, data.pos, radius=self.radius_nn, max_num=-1, mode=0)
    neighbour_count = (neighbour_dist > 0).sum(1)
    keep = neighbour_count > self.min_num
    return apply_mask(data, keep, self.skip_keys)
def __call__(self, data):
    """Remove every point of ``data`` lying within ``self.radius`` of a sphere centre.

    Centres are ``self.centers`` when ``self.name_ind`` is None. Otherwise they
    are the points of ``data.pos`` selected by the indices stored in
    ``data[self.name_ind]``. The remaining points are kept via ``apply_mask``.
    """
    if self.name_ind is None:
        centers = self.centers
    else:
        centers = data.pos[data[self.name_ind].long()]
    ind, dist = ball_query(data.pos, centers, radius=self.radius, max_num=-1, mode=1)
    # Keep zero-distance matches (>= 0). With ``> 0``, a point lying exactly on
    # a centre escaped the dropout, which always happened for centres taken
    # from data.pos.
    ind = ind[dist[:, 0] >= 0]
    # Column 0 is used as an index into data.pos.
    mask = torch.ones(len(data.pos), dtype=torch.bool)
    mask[ind[:, 0]] = False
    return apply_mask(data, mask)
def __call__(self, data):
    """Crop ``data`` to the ball of radius ``self.radius`` around one random point.

    Every attribute whose length equals the number of points is indexed down to
    the points inside the ball. ``data`` is modified in place and returned.
    """
    i = torch.randint(0, len(data.pos), (1,))
    ind, dist = ball_query(data.pos, data.pos[i].view(1, 3), radius=self.radius, max_num=-1, mode=1)
    # >= 0 keeps zero-distance matches, so the centre point itself stays in the
    # crop (``> 0`` silently dropped it).
    ind = ind[dist[:, 0] >= 0]
    inside = ind[:, 0]
    # Capture the point count before data.pos itself gets cropped in the loop.
    size_pos = len(data.pos)
    for k in data.keys:
        if size_pos == len(data[k]):
            data[k] = data[k][inside]
    return data
def __call__(self, data):
    """Remove every point within ``self.radius`` of ``self.num_sphere`` random points of ``data``.

    Centres are drawn uniformly (with replacement) from ``data.pos``. The
    remaining points are kept via ``apply_mask``.
    """
    pos = data.pos
    list_ind = torch.randint(0, len(pos), (self.num_sphere,))
    ind, dist = ball_query(data.pos, data.pos[list_ind], radius=self.radius, max_num=-1, mode=1)
    # >= 0 keeps zero-distance matches. With ``> 0``, each centre (a point of
    # data.pos, at distance 0 from itself) survived its own dropout.
    ind = ind[dist[:, 0] >= 0]
    mask = torch.ones(len(pos), dtype=torch.bool)
    mask[ind[:, 0]] = False
    return apply_mask(data, mask)
def unsupervised_preprocess(self, data_source_o, data_target_o):
    """Build a self-supervised pair of clouds together with their correspondences.

    Each attempt works as follows:
      1. Apply one randomly chosen transform from ``self.ss_transform`` to a
         clone of each input (or use the inputs as-is when it is None).
      2. Take the source points inside a ball of random radius in
         ``[self.min_size_block, self.max_size_block)`` around a random source
         point.
      3. Match each of them to a target point within ``self.max_dist_overlap``.

    Attempts repeat until at least ``self.min_points`` correspondences exist.

    Returns:
        ``(data_source, data_target, new_pair)``. Column 0 of ``new_pair``
        indexes ``data_source.pos`` and column 1 indexes ``data_target.pos``.

    NOTE(review): the loop has no attempt limit; if ``self.min_points`` is
    unreachable for the given inputs it never returns.
    """
    len_col = 0
    while len_col < self.min_points:
        # Pick one augmentation at random for each side (usually a crop).
        if self.ss_transform is not None:
            n1 = np.random.randint(0, len(self.ss_transform.transforms))
            t1 = self.ss_transform.transforms[n1]
            n2 = np.random.randint(0, len(self.ss_transform.transforms))
            t2 = self.ss_transform.transforms[n2]
            data_source = t1(data_source_o.clone())
            data_target = t2(data_target_o.clone())
        else:
            data_source = data_source_o
            data_target = data_target_o
        pos = data_source.pos
        i = torch.randint(0, len(pos), (1,))
        size_block = random.random() * (self.max_size_block - self.min_size_block) + self.min_size_block
        point = pos[i].view(1, 3)
        ind, dist = ball_query(point, pos, radius=size_block, max_num=-1, mode=1)
        # >= 0 keeps zero-distance matches, e.g. the block centre itself.
        _, col = ind[dist[:, 0] >= 0].t()
        ind_t, dist_t = ball_query(data_target.pos, pos[col], radius=self.max_dist_overlap, max_num=1, mode=1)
        # >= 0: a target point coinciding exactly with a source point is a valid
        # match. With ``> 0``, identical untransformed clouds produced no pairs
        # at all, so the loop could not terminate.
        col_target, ind_col = ind_t[dist_t[:, 0] >= 0].t()
        col = col[ind_col]
        new_pair = torch.stack((col, col_target)).T
        len_col = len(new_pair)
    return data_source, data_target, new_pair
def __call__(self, data: Data, ind):
    """Extract the patch of points within ``self.radius_patch`` of ``data.pos[ind]``.

    Returns a new ``Data`` containing, for every tensor attribute of ``data``
    whose first dimension can be indexed by all patch indices, the rows that
    belong to the patch. Other attributes are not copied.
    """
    pos = data.pos
    point = pos[ind].view(1, 3)
    # Use a new name so the ``ind`` argument is not shadowed.
    neighbours, dist = ball_query(point, pos, radius=self.radius_patch, max_num=-1, mode=1)
    # >= 0 keeps the zero-distance match, i.e. the patch centre itself, which
    # ``> 0`` always dropped.
    _, col = neighbours[dist[:, 0] >= 0].t()
    patch = Data()
    for key in data.keys:
        value = data[key]
        if torch.is_tensor(value) and torch.all(col < value.shape[0]):
            patch[key] = value[col]
    return patch
def __call__(self, data):
    """Remove every point within ``self.radius`` of ``self.num_sphere`` random centres.

    Centres are drawn (with replacement) from a grid-subsampled copy of
    ``data`` produced by ``self.grid_sampling``. The remaining points are kept
    via ``apply_mask``.
    """
    subsampled = self.grid_sampling(data.clone())
    picks = torch.randint(0, len(subsampled.pos), (self.num_sphere,))
    centres = subsampled.pos[picks]
    neighbours, neighbour_dist = ball_query(data.pos, centres, radius=self.radius, max_num=-1, mode=1)
    # Column 0 of the surviving rows indexes the points of data.pos to drop.
    inside = neighbours[neighbour_dist[:, 0] >= 0][:, 0]
    keep = torch.ones(len(data.pos), dtype=torch.bool)
    keep[inside] = False
    return apply_mask(data, keep)
def __call__(self, data):
    """Drop points according to the mask computed by ``rw_mask``.

    A mode-0 radius search (``self.radius``, ``self.max_num``) supplies
    neighbour indices and distances. These, the positions and an all-True
    boolean mask are handed to ``rw_mask`` as numpy arrays, with
    ``self.num_iter`` and ``self.dropout_ratio``. The returned mask is applied
    with ``self.skip_keys``.
    """
    points = data.pos.detach().cpu().numpy()
    neigh_idx, neigh_dist = ball_query(data.pos, data.pos, radius=self.radius, max_num=self.max_num, mode=0)
    initial = np.ones(len(points), dtype=bool)
    keep = rw_mask(
        points,
        neigh_idx.detach().cpu().numpy(),
        neigh_dist.detach().cpu().numpy(),
        initial,
        num_iter=self.num_iter,
        random_ratio=self.dropout_ratio,
    )
    return apply_mask(data, keep, self.skip_keys)