def test_pair_ind(self):
    """``PairBatch`` must shift each pair's ``pair_ind`` by the running point counts.

    Source indices are offset by the total number of source points of the
    preceding pairs, target indices by the total number of target points.
    """
    # (number of source points, number of target points, local correspondences)
    cases = [
        (100, 114, [[0, 1], [99, 36], [98, 113], [54, 29], [10, 110], [1, 0]]),
        (102, 104, [[0, 1], [45, 28], [101, 36], [98, 1], [14, 99], [34, 52], [1, 0]]),
        (128, 2102, [[0, 1], [100, 1000], [1, 0]]),
    ]
    pairs = []
    expected_rows = []
    shift_source = shift_target = 0
    for n_source, n_target, local_ind in cases:
        source = Data(pos=torch.randn(n_source, 3))
        target = Data(pos=torch.randn(n_target, 3))
        pair = Pair.make_pair(source, target)
        pair.pair_ind = torch.tensor(local_ind)
        pairs.append(pair)
        # Expected global indices after batching this pair.
        expected_rows += [[i + shift_source, j + shift_target] for i, j in local_ind]
        shift_source += n_source
        shift_target += n_target

    batch = PairBatch.from_data_list(pairs)
    expected_pair_ind = torch.tensor(expected_rows).to(torch.long)
    npt.assert_almost_equal(batch.pair_ind.numpy(), expected_pair_ind.numpy())
def simple_test(self):
    """``Pair.make_pair`` concatenates both sides and exposes per-side attributes.

    The merged ``pos``/``x`` hold the points of both inputs, a ``pair``
    attribute is set, and ``*_source`` / ``*_target`` keep each side's sizes.
    """
    nb_points_1 = 101
    data_source = Pair(
        pos=torch.randn((nb_points_1, 3)),
        x=torch.randn((nb_points_1, 9)),
        norm=torch.randn((nb_points_1, 3)),
        random_feat=torch.randn((nb_points_1, 15)),
    )
    nb_points_2 = 105
    data_target = Pair(
        pos=torch.randn((nb_points_2, 3)),
        x=torch.randn((nb_points_2, 9)),
        norm=torch.randn((nb_points_2, 3)),
        random_feat=torch.randn((nb_points_2, 15)),
    )
    b = Pair.make_pair(data_source, data_target)
    self.assertEqual(b.pos.size(), (nb_points_1 + nb_points_2, 3))
    self.assertEqual(b.x.size(), (nb_points_1 + nb_points_2, 9))
    # Check for the attribute without dereferencing it first: the former
    # debug ``print(b.pair)`` raised AttributeError before this check could run.
    self.assertIsNotNone(getattr(b, "pair", None))
    self.assertEqual(b.pos_source.size(), (nb_points_1, 3))
    self.assertEqual(b.x_source.size(), (nb_points_1, 9))
    self.assertEqual(b.pos_target.size(), (nb_points_2, 3))
    self.assertEqual(b.x_target.size(), (nb_points_2, 9))
def test_pair_batch(self):
    """Batch two one-point pairs and check the merged source/target fields and batch vectors."""
    # Point k carries x = k and pos = k * k (k = 1..4).
    d1, d2, d3, d4 = (
        Data(x=torch.tensor([k]), pos=torch.tensor([k * k])) for k in range(1, 5)
    )
    first_pair = Pair.make_pair(d1, d2)
    second_pair = Pair.make_pair(d3, d4)
    batch = PairBatch.from_data_list([first_pair, second_pair])

    expectations = [
        (batch.x, [1, 3]),
        (batch.pos, [1, 9]),
        (batch.batch, [0, 1]),
        (batch.x_target, [2, 4]),
        (batch.pos_target, [4, 16]),
        (batch.batch_target, [0, 1]),
    ]
    for actual, wanted in expectations:
        tt.assert_allclose(actual, torch.tensor(wanted))
def get_fragment(self, idx):
    """Build the pair for fragment ``idx`` with a subsampled set of correspondences.

    Loads the raw pair via ``self.get_raw_pair``, applies ``self.transform`` to
    each side, and assembles a ``MultiScalePair`` when the transformed source
    has a ``multiscale`` attribute, otherwise a ``Pair``. Correspondences are
    computed online (``compute_overlap_and_matches``) or propagated through the
    transforms (``tracked_matches``). At most ``self.num_pos_pairs`` of them are
    kept, chosen randomly or with ``fps_sampling`` when ``self.use_fps`` is set.

    Returns:
        The contiguous pair, with ``pair_ind`` and ``size_pair_ind`` set.
    """
    import warnings

    data_source, data_target, new_pair = self.get_raw_pair(idx)
    if self.transform is not None:
        data_source = self.transform(data_source)
        data_target = self.transform(data_target)

    if hasattr(data_source, "multiscale"):
        batch = MultiScalePair.make_pair(data_source, data_target)
    else:
        batch = Pair.make_pair(data_source, data_target)

    if self.is_online_matching:
        new_match = compute_overlap_and_matches(
            Data(pos=data_source.pos), Data(pos=data_target.pos), self.max_dist_overlap
        )
        batch.pair_ind = torch.from_numpy(new_match["pair"].copy())
    else:
        batch.pair_ind = tracked_matches(data_source, data_target, new_pair)

    nb_matches = len(batch.pair_ind)
    num_pos_pairs = min(nb_matches, self.num_pos_pairs)
    # Keeping every match (which includes the empty case) needs no FPS. The
    # previous ratio test ``num_pos_pairs / nb_matches >= 1`` raised
    # ZeroDivisionError when no match survived and use_fps was enabled.
    if not self.use_fps or num_pos_pairs >= nb_matches:
        rand_ind = torch.randperm(nb_matches)[:num_pos_pairs]
    else:
        rand_ind = fps_sampling(batch.pair_ind, batch.pos, num_pos_pairs)
    batch.pair_ind = batch.pair_ind[rand_ind]
    batch.size_pair_ind = torch.tensor([num_pos_pairs])
    if len(batch.pair_ind) == 0:
        warnings.warn(
            "get_fragment({}): no correspondences in pair_ind".format(idx),
            RuntimeWarning,
        )
    return batch.contiguous()
def get_fragment(self, idx):
    """Load fragment pair ``idx`` from its match file and attach sampled correspondences.

    Both point clouds are cast to float on load and then passed through
    ``self.transform``. Correspondences come from ``compute_overlap_and_matches``
    when ``self.is_online_matching`` is set, and otherwise from
    ``tracked_matches`` applied to the stored ``pair`` array. At most
    ``self.num_pos_pairs`` of them are kept, in random order.
    """
    match_file = osp.join(self.path_match, "matches{:06d}.npy".format(idx))
    match = np.load(match_file, allow_pickle=True).item()
    data_source = torch.load(match["path_source"]).to(torch.float)
    data_target = torch.load(match["path_target"]).to(torch.float)
    new_pair = torch.from_numpy(match["pair"])

    if self.transform is not None:
        data_source, data_target = self.transform(data_source), self.transform(data_target)

    pair_cls = MultiScalePair if hasattr(data_source, "multiscale") else Pair
    batch = pair_cls.make_pair(data_source, data_target)

    if not self.is_online_matching:
        batch.pair_ind = tracked_matches(data_source, data_target, new_pair)
    else:
        online = compute_overlap_and_matches(
            Data(pos=data_source.pos), Data(pos=data_target.pos), self.max_dist_overlap
        )
        batch.pair_ind = torch.from_numpy(online["pair"].copy())

    available = len(batch.pair_ind)
    num_pos_pairs = self.num_pos_pairs if self.num_pos_pairs < available else available
    keep = torch.randperm(available)[:num_pos_pairs]
    batch.pair_ind = batch.pair_ind[keep]
    batch.size_pair_ind = torch.tensor([num_pos_pairs])
    return batch.contiguous()
def datalist(self):
    """Return ``self.batch_size`` seeded random pairs sharing features and matches.

    The RNG is reseeded with 0, so the source clouds (all drawn first) and
    then the target clouds are reproducible. Each side optionally goes through
    ``self._transform`` and then ``self._ms_transform``, applied to clones.
    """
    torch.manual_seed(0)

    def _random_side():
        # One Data per batch element, random positions, shared x / pair_ind.
        return [
            Data(
                pos=torch.randn((self.num_points, 3)),
                x=self._feature,
                pair_ind=self._pair_ind,
                size_pair_ind=torch.tensor([len(self._pair_ind)]),
            )
            for _ in range(self.batch_size)
        ]

    sources = _random_side()
    targets = _random_side()
    for transform in (self._transform, self._ms_transform):
        if transform:
            sources = [transform(d.clone()) for d in sources]
            targets = [transform(d.clone()) for d in targets]
    return [Pair.make_pair(src, tgt) for src, tgt in zip(sources, targets)]
def get_fragment(self, idx):
    """Load fragment pair ``idx`` from its match file and attach sampled correspondences.

    Correspondences come from ``compute_overlap_and_matches`` when
    ``self.is_online_matching`` is set, and otherwise from ``tracked_matches``
    applied to the stored ``pair`` array. At most ``self.num_pos_pairs`` are
    kept, in random order.

    Returns:
        The contiguous ``Pair``. Point data is float; ``pair_ind`` and
        ``size_pair_ind`` keep their integer dtype.
    """
    match = np.load(osp.join(self.path_match, 'matches{:06d}.npy'.format(idx)), allow_pickle=True).item()
    # Cast the point clouds rather than the assembled batch: ``batch.to(torch.float)``
    # also turned the integer correspondence indices (pair_ind, size_pair_ind) into floats.
    data_source = torch.load(match['path_source']).to(torch.float)
    data_target = torch.load(match['path_target']).to(torch.float)
    new_pair = torch.from_numpy(match['pair'])
    if self.transform is not None:
        data_source = self.transform(data_source)
        data_target = self.transform(data_target)
    batch = Pair.make_pair(data_source, data_target)
    if self.is_online_matching:
        new_match = compute_overlap_and_matches(
            Data(pos=data_source.pos), Data(pos=data_target.pos), self.max_dist_overlap
        )
        batch.pair_ind = torch.from_numpy(new_match['pair'].copy())
    else:
        batch.pair_ind = tracked_matches(data_source, data_target, new_pair)

    num_pos_pairs = len(batch.pair_ind)
    if self.num_pos_pairs < len(batch.pair_ind):
        num_pos_pairs = self.num_pos_pairs
    rand_ind = torch.randperm(len(batch.pair_ind))[:num_pos_pairs]
    batch.pair_ind = batch.pair_ind[rand_ind]
    batch.size_pair_ind = torch.tensor([num_pos_pairs])
    return batch.contiguous()
def get_patch_offline(self, idx):
    """Load the pre-extracted source/target patches for ``idx`` as a float ``Pair``.

    Each patch is read from ``self.path_data``, optionally passed through
    ``self.transform``, and the assembled pair is made contiguous and cast
    to ``torch.float``.
    """
    templates = ("patches_source{:06d}.pt", "patches_target{:06d}.pt")
    data_source, data_target = (
        torch.load(osp.join(self.path_data, template.format(idx))) for template in templates
    )
    if self.transform is not None:
        data_source = self.transform(data_source)
        data_target = self.transform(data_target)
    return Pair.make_pair(data_source, data_target).contiguous().to(torch.float)
def get_patch_offline(self, idx):
    """Load the pre-extracted source/target patches for ``idx`` as a contiguous pair.

    A ``MultiScalePair`` is built when the transformed source has a
    ``multiscale`` attribute; otherwise a plain ``Pair``.
    """
    source_path = osp.join(self.path_data, "patches_source{:06d}.pt".format(idx))
    target_path = osp.join(self.path_data, "patches_target{:06d}.pt".format(idx))
    data_source = torch.load(source_path)
    data_target = torch.load(target_path)
    if self.transform is not None:
        data_source, data_target = self.transform(data_source), self.transform(data_target)
    pair_cls = MultiScalePair if hasattr(data_source, "multiscale") else Pair
    return pair_cls.make_pair(data_source, data_target).contiguous()
def get_fragment(self, idx):
    """Load fragment pair ``idx`` and store its precomputed matches in ``batch.y``.

    The match file lives under ``<processed_dir>/<mode>/matches``. Both clouds
    are optionally passed through ``self.transform``. The assembled pair is
    made contiguous and cast to ``torch.float``.
    """
    # The former debug ``print(match['path_source'])`` has been removed: it
    # wrote to stdout on every sample fetch.
    match_path = osp.join(self.processed_dir, self.mode, 'matches', 'matches{:06d}.npy'.format(idx))
    match = np.load(match_path, allow_pickle=True).item()
    data_source = torch.load(match['path_source'])
    data_target = torch.load(match['path_target'])
    if self.transform is not None:
        data_source = self.transform(data_source)
        data_target = self.transform(data_target)
    batch = Pair.make_pair(data_source, data_target)
    batch.y = torch.from_numpy(match['pair'])
    # NOTE(review): this cast also converts the match indices in ``y`` to float;
    # confirm consumers expect that (sibling loaders keep indices integral).
    return batch.contiguous().to(torch.float)
def get_patch_online(self, idx):
    """Cut a source/target patch pair around one randomly chosen match of fragment ``idx``.

    Patches of radius ``self.radius_patch`` are extracted around both ends of
    the chosen correspondence, optionally transformed, and returned as a
    contiguous ``Pair``.
    """
    extractor = PatchExtractor(self.radius_patch)
    match_file = osp.join(self.path_data, "matches{:06d}.npy".format(idx))
    match = np.load(match_file, allow_pickle=True).item()
    data_source = torch.load(match["path_source"]).to(torch.float)
    data_target = torch.load(match["path_target"]).to(torch.float)

    # Choose one correspondence at random. The match list is expected to be
    # non-empty because matches are filtered upstream.
    matches = match["pair"]
    chosen = matches[np.random.randint(0, len(matches))]
    data_source = extractor(data_source, chosen[0])
    data_target = extractor(data_target, chosen[1])

    if self.transform is not None:
        data_source, data_target = self.transform(data_source), self.transform(data_target)
    return Pair.make_pair(data_source, data_target).contiguous()
def __call__(self, data):
    """Split ``data`` into its two sides, transform each, and re-pair them."""
    source, target = data.to_data()
    return Pair.make_pair(self.transform(source), self.transform(target))