コード例 #1
0
    def get_fragment(self, idx):
        """Build the paired batch for the fragment pair at index ``idx``.

        The raw source/target data and their stored matches come from
        ``self.get_raw_pair``. Both items go through ``self.transform`` when
        one is set, are merged with ``Pair.make_pair`` (or
        ``MultiScalePair.make_pair`` when the transformed source has a
        ``multiscale`` attribute), and the positive matches are attached to
        the batch as ``pair_ind``.

        At most ``self.num_pos_pairs`` matches are kept. They are drawn with
        a random permutation, unless ``self.use_fps`` is set and there are
        more matches than requested, in which case ``fps_sampling`` picks
        them.

        Args:
            idx: Index forwarded to ``self.get_raw_pair``.

        Returns:
            The result of ``batch.contiguous()``, with ``pair_ind`` holding
            the selected matches and ``size_pair_ind`` holding their count
            as a one-element tensor.
        """
        data_source, data_target, new_pair = self.get_raw_pair(idx)
        if self.transform is not None:
            data_source = self.transform(data_source)
            data_target = self.transform(data_target)

        if hasattr(data_source, "multiscale"):
            batch = MultiScalePair.make_pair(data_source, data_target)
        else:
            batch = Pair.make_pair(data_source, data_target)

        if self.is_online_matching:
            # Recompute the matches from the (possibly transformed) positions
            # instead of relying on the stored ones.
            new_match = compute_overlap_and_matches(
                Data(pos=data_source.pos),
                Data(pos=data_target.pos),
                self.max_dist_overlap,
            )
            # torch.from_numpy shares memory with its input; the copy keeps
            # the tensor independent of the array held by ``new_match``.
            batch.pair_ind = torch.from_numpy(new_match["pair"].copy())
        else:
            # NOTE(review): presumably remaps the stored matches onto the
            # points that survived the transform -- confirm in tracked_matches.
            batch.pair_ind = tracked_matches(data_source, data_target, new_pair)

        num_matches = len(batch.pair_ind)
        num_pos_pairs = min(num_matches, self.num_pos_pairs)

        # Compare the counts directly rather than through a ratio: the ratio
        # form divided by ``num_matches`` and raised ZeroDivisionError when no
        # match was found and ``use_fps`` was enabled.
        if not self.use_fps or num_pos_pairs >= num_matches:
            rand_ind = torch.randperm(num_matches)[:num_pos_pairs]
        else:
            rand_ind = fps_sampling(batch.pair_ind, batch.pos, num_pos_pairs)
        batch.pair_ind = batch.pair_ind[rand_ind]
        batch.size_pair_ind = torch.tensor([num_pos_pairs])
        if len(batch.pair_ind) == 0:
            # No positive pair is left for this fragment; the batch is still
            # returned so the caller decides what to do with it.
            print("Warning")
        return batch.contiguous()
コード例 #2
0
    def get_fragment(self, idx):
        """Load the match file for ``idx`` and return the paired batch.

        The match dictionary stored under ``self.path_match`` provides the
        paths of the source and target data and the matched index pairs.
        Both items are passed through ``self.transform`` when one is set and
        merged with ``Pair.make_pair``. At most ``self.num_pos_pairs``
        randomly chosen matches are kept in ``batch.pair_ind``, and their
        count is stored in ``batch.size_pair_ind``.

        Args:
            idx: Number used to format the match file name.

        Returns:
            ``batch.contiguous().to(torch.float)``.
        """
        match_file = osp.join(self.path_match, 'matches{:06d}.npy'.format(idx))
        # The file stores a pickled dict inside a 0-d object array.
        match = np.load(match_file, allow_pickle=True).item()
        data_source = torch.load(match['path_source'])
        data_target = torch.load(match['path_target'])
        # The stored pairs are used as-is; a call to
        # compute_subsampled_matches was left disabled at this point.
        new_pair = torch.from_numpy(match['pair'])

        transform = self.transform
        if transform is not None:
            data_source, data_target = transform(data_source), transform(data_target)

        batch = Pair.make_pair(data_source, data_target)
        if not self.is_online_matching:
            batch.pair_ind = tracked_matches(data_source, data_target, new_pair)
        else:
            # Recompute the matches from the positions after the transform.
            source_pos = Data(pos=data_source.pos)
            target_pos = Data(pos=data_target.pos)
            new_match = compute_overlap_and_matches(source_pos, target_pos,
                                                    self.max_dist_overlap)
            batch.pair_ind = torch.from_numpy(new_match['pair'].copy())

        # Keep no more than the configured number of positive pairs.
        total_pairs = len(batch.pair_ind)
        num_pos_pairs = min(total_pairs, self.num_pos_pairs)

        shuffled = torch.randperm(total_pairs)
        rand_ind = shuffled[:num_pos_pairs]
        batch.pair_ind = batch.pair_ind[rand_ind]
        batch.size_pair_ind = torch.tensor([num_pos_pairs])
        contiguous_batch = batch.contiguous()
        return contiguous_batch.to(torch.float)
コード例 #3
0
    def get_fragment(self, idx):
        """Load the match file for ``idx`` and return the paired batch.

        The match dictionary stored under ``self.path_match`` provides the
        paths of the source and target data (converted with
        ``.to(torch.float)`` right after loading) and the matched index
        pairs. Both items are passed through ``self.transform`` when one is
        set, then merged with ``MultiScalePair.make_pair`` if the transformed
        source has a ``multiscale`` attribute, and with ``Pair.make_pair``
        otherwise. At most ``self.num_pos_pairs`` randomly chosen matches are
        kept in ``batch.pair_ind``, and their count is stored in
        ``batch.size_pair_ind``.

        Args:
            idx: Number used to format the match file name.

        Returns:
            ``batch.contiguous()``.
        """
        match_file = osp.join(self.path_match, "matches{:06d}.npy".format(idx))
        # The file stores a pickled dict inside a 0-d object array.
        match = np.load(match_file, allow_pickle=True).item()
        data_source = torch.load(match["path_source"]).to(torch.float)
        data_target = torch.load(match["path_target"]).to(torch.float)
        new_pair = torch.from_numpy(match["pair"])

        transform = self.transform
        if transform is not None:
            data_source, data_target = transform(data_source), transform(data_target)

        # Only the selected class name is looked up, as in an if/else.
        pair_cls = MultiScalePair if hasattr(data_source, "multiscale") else Pair
        batch = pair_cls.make_pair(data_source, data_target)

        if not self.is_online_matching:
            batch.pair_ind = tracked_matches(data_source, data_target, new_pair)
        else:
            # Recompute the matches from the positions after the transform.
            source_pos = Data(pos=data_source.pos)
            target_pos = Data(pos=data_target.pos)
            new_match = compute_overlap_and_matches(source_pos, target_pos, self.max_dist_overlap)
            batch.pair_ind = torch.from_numpy(new_match["pair"].copy())

        # Keep no more than the configured number of positive pairs.
        total_pairs = len(batch.pair_ind)
        num_pos_pairs = min(total_pairs, self.num_pos_pairs)

        shuffled = torch.randperm(total_pairs)
        rand_ind = shuffled[:num_pos_pairs]
        batch.pair_ind = batch.pair_ind[rand_ind]
        batch.size_pair_ind = torch.tensor([num_pos_pairs])
        return batch.contiguous()
コード例 #4
0
    def test_simple(self):
        """Check ``tracked_matches`` on a small hand-built example."""
        # ``origin_id`` values attached to the three points of each cloud.
        source_ids = torch.tensor([1, 2, 5])
        target_ids = torch.tensor([0, 5, 6])
        source = Data(pos=torch.randn(3, 3), origin_id=source_ids)
        target = Data(pos=torch.randn(3, 3), origin_id=target_ids)
        # Matches given as (source, target) index rows.
        # NOTE(review): presumably expressed in the ``origin_id`` numbering;
        # only row [2, 0] has both ends present in the two id lists, at
        # position 1 in the source and position 0 in the target -- confirm
        # against tracked_matches.
        matches = torch.tensor([[0, 2], [1, 3], [2, 0], [3, 1]])

        tracked = tracked_matches(source, target, matches)

        actual = tracked.detach().cpu().numpy()
        wanted = np.array([[1, 0]])
        npt.assert_array_almost_equal(actual, wanted)