def test_pair_ind(self):
        """Batching pairs must offset each pair_ind by the cumulative sizes
        of the preceding source and target clouds."""
        sizes = [(100, 114), (102, 104), (128, 2102)]
        raw_inds = [
            [[0, 1], [99, 36], [98, 113], [54, 29], [10, 110], [1, 0]],
            [[0, 1], [45, 28], [101, 36], [98, 1], [14, 99], [34, 52], [1, 0]],
            [[0, 1], [100, 1000], [1, 0]],
        ]
        pairs = []
        for (n_src, n_tgt), ind in zip(sizes, raw_inds):
            pair = Pair.make_pair(Data(pos=torch.randn(n_src, 3)),
                                  Data(pos=torch.randn(n_tgt, 3)))
            pair.pair_ind = torch.tensor(ind)
            pairs.append(pair)

        batch = PairBatch.from_data_list(pairs)

        # Expected indices: each pair's correspondences shifted by the total
        # number of source/target points that came before it in the batch.
        expected_rows = []
        src_off = tgt_off = 0
        for (n_src, n_tgt), ind in zip(sizes, raw_inds):
            expected_rows.extend([s + src_off, t + tgt_off] for s, t in ind)
            src_off += n_src
            tgt_off += n_tgt
        expected_pair_ind = torch.tensor(expected_rows).to(torch.long)

        npt.assert_almost_equal(batch.pair_ind.numpy(),
                                expected_pair_ind.numpy())
    def simple_test(self):
        """make_pair concatenates source/target attributes and keeps per-side
        views (``*_source`` / ``*_target``) at their original sizes."""
        nb_points_1 = 101
        data_source = Pair(
            pos=torch.randn((nb_points_1, 3)),
            x=torch.randn((nb_points_1, 9)),
            norm=torch.randn((nb_points_1, 3)),
            random_feat=torch.randn((nb_points_1, 15)),
        )
        nb_points_2 = 105
        data_target = Pair(
            pos=torch.randn((nb_points_2, 3)),
            x=torch.randn((nb_points_2, 9)),
            norm=torch.randn((nb_points_2, 3)),
            random_feat=torch.randn((nb_points_2, 15)),
        )

        b = Pair.make_pair(data_source, data_target)
        # Concatenated attributes cover both clouds.
        self.assertEqual(b.pos.size(), (nb_points_1 + nb_points_2, 3))
        self.assertEqual(b.x.size(), (nb_points_1 + nb_points_2, 9))
        # Was a debug print() plus a bare assert; use the unittest assertion
        # so a failure is reported through the framework.
        self.assertIsNotNone(getattr(b, "pair", None))
        # Per-side views keep the original point counts.
        self.assertEqual(b.pos_source.size(), (nb_points_1, 3))
        self.assertEqual(b.x_source.size(), (nb_points_1, 9))
        self.assertEqual(b.pos_target.size(), (nb_points_2, 3))
        self.assertEqual(b.x_target.size(), (nb_points_2, 9))
 def test_pair_batch(self):
     """Batched pairs expose source tensors directly and targets via ``*_target``."""
     def make(v):
         # x = v, pos = v squared, matching the fixture values below.
         return Data(x=torch.tensor([v]), pos=torch.tensor([v * v]))

     pairs = [Pair.make_pair(make(1), make(2)),
              Pair.make_pair(make(3), make(4))]
     batch = PairBatch.from_data_list(pairs)
     # Source side of the batch.
     tt.assert_allclose(batch.x, torch.tensor([1, 3]))
     tt.assert_allclose(batch.pos, torch.tensor([1, 9]))
     tt.assert_allclose(batch.batch, torch.tensor([0, 1]))
     # Target side of the batch.
     tt.assert_allclose(batch.x_target, torch.tensor([2, 4]))
     tt.assert_allclose(batch.pos_target, torch.tensor([4, 16]))
     tt.assert_allclose(batch.batch_target, torch.tensor([0, 1]))
# Example 4
    def get_fragment(self, idx):
        """Build the pair for fragment ``idx``, attach correspondence indices,
        and subsample them to at most ``self.num_pos_pairs``.

        Returns a contiguous ``Pair`` (or ``MultiScalePair``) with
        ``pair_ind`` and ``size_pair_ind`` set.
        """
        data_source, data_target, new_pair = self.get_raw_pair(idx)
        if self.transform is not None:
            data_source = self.transform(data_source)
            data_target = self.transform(data_target)

        # Multiscale transforms mark the data; pick the matching pair type.
        if hasattr(data_source, "multiscale"):
            batch = MultiScalePair.make_pair(data_source, data_target)
        else:
            batch = Pair.make_pair(data_source, data_target)

        if self.is_online_matching:
            # Recompute matches on the (possibly transformed) clouds.
            new_match = compute_overlap_and_matches(Data(pos=data_source.pos),
                                                    Data(pos=data_target.pos),
                                                    self.max_dist_overlap)
            batch.pair_ind = torch.from_numpy(new_match["pair"].copy())
        else:
            # Remap the precomputed matches through the applied transforms.
            batch.pair_ind = tracked_matches(data_source, data_target, new_pair)

        total = len(batch.pair_ind)
        if total == 0:
            # Warn before sampling (previously only after), so the message is
            # emitted even if sampling itself fails.
            print("Warning")
        # Cap the number of kept correspondences.
        num_pos_pairs = min(self.num_pos_pairs, total)

        # FPS is only useful when we actually subsample; the integer
        # comparison also fixes the ZeroDivisionError the old ratio test
        # (num_pos_pairs / total >= 1) hit when pair_ind was empty.
        if not self.use_fps or num_pos_pairs >= total:
            rand_ind = torch.randperm(total)[:num_pos_pairs]
        else:
            rand_ind = fps_sampling(batch.pair_ind, batch.pos, num_pos_pairs)
        batch.pair_ind = batch.pair_ind[rand_ind]
        batch.size_pair_ind = torch.tensor([num_pos_pairs])
        return batch.contiguous()
# Example 5
    def get_fragment(self, idx):
        """Load fragment pair ``idx`` from disk, attach correspondence
        indices, and subsample them to at most ``self.num_pos_pairs``."""
        match_path = osp.join(self.path_match, "matches{:06d}.npy".format(idx))
        match = np.load(match_path, allow_pickle=True).item()
        data_source = torch.load(match["path_source"]).to(torch.float)
        data_target = torch.load(match["path_target"]).to(torch.float)
        new_pair = torch.from_numpy(match["pair"])

        if self.transform is not None:
            data_source = self.transform(data_source)
            data_target = self.transform(data_target)

        # Multiscale transforms mark the data; pick the matching pair type.
        pair_cls = MultiScalePair if hasattr(data_source, "multiscale") else Pair
        batch = pair_cls.make_pair(data_source, data_target)

        if self.is_online_matching:
            # Recompute matches on the (possibly transformed) clouds.
            new_match = compute_overlap_and_matches(
                Data(pos=data_source.pos), Data(pos=data_target.pos), self.max_dist_overlap
            )
            batch.pair_ind = torch.from_numpy(new_match["pair"].copy())
        else:
            # Remap the precomputed matches through the applied transforms.
            batch.pair_ind = tracked_matches(data_source, data_target, new_pair)

        num_pos_pairs = len(batch.pair_ind)
        if self.num_pos_pairs < num_pos_pairs:
            num_pos_pairs = self.num_pos_pairs

        # Random subsample of the correspondences.
        perm = torch.randperm(len(batch.pair_ind))[:num_pos_pairs]
        batch.pair_ind = batch.pair_ind[perm]
        batch.size_pair_ind = torch.tensor([num_pos_pairs])
        return batch.contiguous()
# Example 6
 def datalist(self):
     """Return ``batch_size`` source/target Pairs of random clouds that share
     the fixture's features and correspondence indices."""
     torch.manual_seed(0)

     def _clouds():
         # One list of random clouds with the shared fixture attributes.
         return [
             Data(
                 pos=torch.randn((self.num_points, 3)),
                 x=self._feature,
                 pair_ind=self._pair_ind,
                 size_pair_ind=torch.tensor([len(self._pair_ind)]),
             )
             for _ in range(self.batch_size)
         ]

     # Sources are drawn before targets so the randn call order (and hence
     # the seeded values) matches the original construction order.
     sources = _clouds()
     targets = _clouds()
     if self._transform:
         sources = [self._transform(d.clone()) for d in sources]
         targets = [self._transform(d.clone()) for d in targets]
     if self._ms_transform:
         sources = [self._ms_transform(d.clone()) for d in sources]
         targets = [self._ms_transform(d.clone()) for d in targets]
     return [Pair.make_pair(s, t) for s, t in zip(sources, targets)]
# Example 7
    def get_fragment(self, idx):
        """Load fragment pair ``idx``, attach correspondence indices, and
        subsample them to at most ``self.num_pos_pairs``."""
        match = np.load(osp.join(self.path_match,
                                 'matches{:06d}.npy'.format(idx)),
                        allow_pickle=True).item()
        data_source = torch.load(match['path_source'])
        data_target = torch.load(match['path_target'])
        new_pair = torch.from_numpy(match['pair'])

        if self.transform is not None:
            data_source = self.transform(data_source)
            data_target = self.transform(data_target)

        batch = Pair.make_pair(data_source, data_target)
        if self.is_online_matching:
            # Recompute matches on the (possibly transformed) clouds.
            fresh = compute_overlap_and_matches(Data(pos=data_source.pos),
                                                Data(pos=data_target.pos),
                                                self.max_dist_overlap)
            batch.pair_ind = torch.from_numpy(fresh['pair'].copy())
        else:
            # Remap the precomputed matches through the applied transforms.
            batch.pair_ind = tracked_matches(data_source, data_target, new_pair)

        total = len(batch.pair_ind)
        num_pos_pairs = self.num_pos_pairs if self.num_pos_pairs < total else total

        # Random subsample of the correspondences.
        keep = torch.randperm(total)[:num_pos_pairs]
        batch.pair_ind = batch.pair_ind[keep]
        batch.size_pair_ind = torch.tensor([num_pos_pairs])
        return batch.contiguous().to(torch.float)
 def get_patch_offline(self, idx):
     """Load precomputed source/target patches ``idx`` and pair them as floats."""
     source_path = osp.join(self.path_data, "patches_source{:06d}.pt".format(idx))
     target_path = osp.join(self.path_data, "patches_target{:06d}.pt".format(idx))
     data_source = torch.load(source_path)
     data_target = torch.load(target_path)
     if self.transform is not None:
         data_source = self.transform(data_source)
         data_target = self.transform(data_target)
     return Pair.make_pair(data_source, data_target).contiguous().to(torch.float)
# Example 9
    def get_patch_offline(self, idx):
        """Load precomputed patches ``idx`` and return them as a contiguous
        pair (``MultiScalePair`` when the data carries multiscale info)."""
        data_source = torch.load(
            osp.join(self.path_data, "patches_source{:06d}.pt".format(idx)))
        data_target = torch.load(
            osp.join(self.path_data, "patches_target{:06d}.pt".format(idx)))
        if self.transform is not None:
            data_source = self.transform(data_source)
            data_target = self.transform(data_target)

        pair_cls = MultiScalePair if hasattr(data_source, "multiscale") else Pair
        return pair_cls.make_pair(data_source, data_target).contiguous()
    def get_fragment(self, idx):
        """Load matched fragment pair ``idx``; the ground-truth match indices
        are stored on ``batch.y``.

        Returns a contiguous float ``Pair``.
        """
        match = np.load(osp.join(self.processed_dir, self.mode, 'matches',
                                 'matches{:06d}.npy'.format(idx)),
                        allow_pickle=True).item()
        # Removed a leftover debug print of match['path_source'].
        data_source = torch.load(match['path_source'])
        data_target = torch.load(match['path_target'])
        if self.transform is not None:
            data_source = self.transform(data_source)
            data_target = self.transform(data_target)
        batch = Pair.make_pair(data_source, data_target)
        batch.y = torch.from_numpy(match['pair'])
        return batch.contiguous().to(torch.float)
# Example 11
    def get_patch_online(self, idx):
        """Extract one randomly chosen matching patch pair from fragment
        pair ``idx`` and return it as a contiguous ``Pair``."""
        p_extractor = PatchExtractor(self.radius_patch)

        match = np.load(osp.join(self.path_data, "matches{:06d}.npy".format(idx)), allow_pickle=True).item()
        source_cloud = torch.load(match["path_source"]).to(torch.float)
        target_cloud = torch.load(match["path_target"]).to(torch.float)

        # Pick one correspondence at random; the match list is pre-filtered,
        # so it is never empty.
        rand = np.random.randint(0, len(match["pair"]))
        source_idx, target_idx = match["pair"][rand][0], match["pair"][rand][1]

        data_source = p_extractor(source_cloud, source_idx)
        data_target = p_extractor(target_cloud, target_idx)

        if self.transform is not None:
            data_source = self.transform(data_source)
            data_target = self.transform(data_target)
        return Pair.make_pair(data_source, data_target).contiguous()
# Example 12
 def __call__(self, data):
     """Apply the wrapped transform to each side of ``data`` and re-pair them."""
     source, target = data.to_data()
     return Pair.make_pair(self.transform(source), self.transform(target))