def get_fragment(self, idx):
    """Assemble the source/target pair for sample ``idx`` with subsampled matches.

    The raw clouds and their stored correspondences come from
    ``self.get_raw_pair(idx)``. Both clouds go through ``self.transform`` when
    one is set. Correspondences are then either recomputed with
    ``compute_overlap_and_matches`` on the transformed positions (when
    ``self.is_online_matching``) or re-indexed with ``tracked_matches``.

    At most ``self.num_pos_pairs`` correspondences are kept. They are chosen by
    a random permutation, or by ``fps_sampling`` when ``self.use_fps`` is set
    and subsampling is actually needed.

    Args:
        idx: Index of the sample to load.

    Returns:
        ``batch.contiguous()``, where ``batch`` is a ``MultiScalePair`` if the
        transformed source has a ``multiscale`` attribute and a ``Pair``
        otherwise. ``pair_ind`` and ``size_pair_ind`` are populated.
    """
    import warnings

    data_source, data_target, new_pair = self.get_raw_pair(idx)
    if self.transform is not None:
        data_source = self.transform(data_source)
        data_target = self.transform(data_target)

    if hasattr(data_source, "multiscale"):
        batch = MultiScalePair.make_pair(data_source, data_target)
    else:
        batch = Pair.make_pair(data_source, data_target)

    if self.is_online_matching:
        new_match = compute_overlap_and_matches(
            Data(pos=data_source.pos), Data(pos=data_target.pos), self.max_dist_overlap
        )
        batch.pair_ind = torch.from_numpy(new_match["pair"].copy())
    else:
        batch.pair_ind = tracked_matches(data_source, data_target, new_pair)

    num_matches = len(batch.pair_ind)
    # Keep at most self.num_pos_pairs matches (ties resolve to num_matches).
    num_pos_pairs = min(num_matches, self.num_pos_pairs)

    # Integer comparison instead of the former ratio `num_pos_pairs / num_matches >= 1`.
    # The ratio divided by zero when no matches were found and use_fps was enabled.
    if not self.use_fps or num_pos_pairs >= num_matches:
        rand_ind = torch.randperm(num_matches)[:num_pos_pairs]
    else:
        rand_ind = fps_sampling(batch.pair_ind, batch.pos, num_pos_pairs)
    batch.pair_ind = batch.pair_ind[rand_ind]
    batch.size_pair_ind = torch.tensor([num_pos_pairs])

    if len(batch.pair_ind) == 0:
        warnings.warn(f"get_fragment: no positive pairs left for sample {idx}")
    return batch.contiguous()
def get_fragment(self, idx):
    """Load the pre-computed match file for ``idx`` and assemble a ``Pair``.

    The match record ``matches{idx:06d}.npy`` under ``self.path_match`` holds
    the paths of the source and target clouds (``path_source`` /
    ``path_target``) and their stored correspondences (``pair``). Both clouds
    are cast to float and go through ``self.transform`` when one is set.
    Correspondences are then either recomputed with
    ``compute_overlap_and_matches`` (when ``self.is_online_matching``) or
    re-indexed with ``tracked_matches``. At most ``self.num_pos_pairs`` of them
    are kept, chosen by a random permutation.

    Args:
        idx: Index of the sample / match file to load.

    Returns:
        The contiguous ``Pair`` with ``pair_ind`` and ``size_pair_ind`` set.
    """
    # NOTE(review): allow_pickle=True and torch.load both unpickle arbitrary
    # objects. Only point self.path_match at trusted, locally generated files.
    match = np.load(
        osp.join(self.path_match, "matches{:06d}.npy".format(idx)), allow_pickle=True
    ).item()
    # Cast to float right after loading, i.e. before `transform` runs and before
    # the integer index tensors (pair_ind, size_pair_ind) are attached.
    # Casting the finished batch would turn those indices into floats, which
    # can no longer index tensors.
    data_source = torch.load(match["path_source"]).to(torch.float)
    data_target = torch.load(match["path_target"]).to(torch.float)
    new_pair = torch.from_numpy(match["pair"])

    if self.transform is not None:
        data_source = self.transform(data_source)
        data_target = self.transform(data_target)

    batch = Pair.make_pair(data_source, data_target)

    if self.is_online_matching:
        new_match = compute_overlap_and_matches(
            Data(pos=data_source.pos), Data(pos=data_target.pos), self.max_dist_overlap
        )
        batch.pair_ind = torch.from_numpy(new_match["pair"].copy())
    else:
        batch.pair_ind = tracked_matches(data_source, data_target, new_pair)

    num_matches = len(batch.pair_ind)
    # Keep at most self.num_pos_pairs matches (ties resolve to num_matches).
    num_pos_pairs = min(num_matches, self.num_pos_pairs)
    rand_ind = torch.randperm(num_matches)[:num_pos_pairs]
    batch.pair_ind = batch.pair_ind[rand_ind]
    batch.size_pair_ind = torch.tensor([num_pos_pairs])
    return batch.contiguous()
def get_fragment(self, idx):
    """Read match record ``idx`` from disk and turn it into a (multi-scale) pair.

    The record ``matches{idx:06d}.npy`` in ``self.path_match`` supplies three
    entries: the source path, the target path, and the stored correspondences.
    Both clouds are cast to float on load and passed through
    ``self.transform`` if one is configured. The pair container is
    ``MultiScalePair`` when the transformed source exposes ``multiscale`` and
    ``Pair`` otherwise.

    Correspondences are recomputed online with ``compute_overlap_and_matches``
    or tracked through the transforms with ``tracked_matches``. A random subset
    of at most ``self.num_pos_pairs`` of them is retained.

    Args:
        idx: Index of the match record to load.

    Returns:
        The contiguous pair with ``pair_ind`` and ``size_pair_ind`` attached.
    """
    record_path = osp.join(self.path_match, "matches{:06d}.npy".format(idx))
    record = np.load(record_path, allow_pickle=True).item()
    src = torch.load(record["path_source"]).to(torch.float)
    tgt = torch.load(record["path_target"]).to(torch.float)
    stored_pairs = torch.from_numpy(record["pair"])

    if self.transform is not None:
        src, tgt = self.transform(src), self.transform(tgt)

    pair_cls = MultiScalePair if hasattr(src, "multiscale") else Pair
    batch = pair_cls.make_pair(src, tgt)

    if not self.is_online_matching:
        batch.pair_ind = tracked_matches(src, tgt, stored_pairs)
    else:
        overlap = compute_overlap_and_matches(
            Data(pos=src.pos), Data(pos=tgt.pos), self.max_dist_overlap
        )
        batch.pair_ind = torch.from_numpy(overlap["pair"].copy())

    total = len(batch.pair_ind)
    # Cap the number of positives; on a tie the count stays `total`.
    kept = min(total, self.num_pos_pairs)
    selection = torch.randperm(total)[:kept]
    batch.pair_ind = batch.pair_ind[selection]
    batch.size_pair_ind = torch.tensor([kept])
    return batch.contiguous()
def test_simple(self):
    """Check ``tracked_matches`` on a tiny hand-built example.

    Source origin ids are [1, 2, 5] and target origin ids are [0, 5, 6]. Of the
    four candidate pairs, only (2, 0) has both endpoints among those ids. The
    expected output [[1, 0]] is presumably that pair expressed in local
    indices (id 2 is at position 1 in the source, id 0 at position 0 in the
    target).
    """
    src = Data(pos=torch.randn(3, 3), origin_id=torch.tensor([1, 2, 5]))
    tgt = Data(pos=torch.randn(3, 3), origin_id=torch.tensor([0, 5, 6]))
    candidates = torch.tensor([[0, 2], [1, 3], [2, 0], [3, 1]])

    tracked = tracked_matches(src, tgt, candidates)

    npt.assert_array_almost_equal(tracked.detach().cpu().numpy(), np.array([[1, 0]]))