Ejemplo n.º 1
0
def compute_overlap_and_matches(path1,
                                path2,
                                max_distance_overlap,
                                reciprocity=False,
                                num_pos=1):
    """Load two saved fragments and match their points within a radius.

    Args:
        path1: file read with ``torch.load``; the loaded object must expose
            a ``pos`` attribute.
        path2: second file, same requirements as ``path1``.
        max_distance_overlap: search radius forwarded to ``ball_query``.
        reciprocity: when True, also run the query with the roles of the two
            clouds swapped and append a second overlap ratio.
        num_pos: maximum number of neighbours per query (``max_num``).

    Returns:
        dict with keys ``pair`` (filtered matches of
        ``ball_query(<path2 points>, <path1 points>)``), ``path_source``
        (``path1``), ``path_target`` (``path2``) and ``overlap`` (a list
        holding ``len(pair) / len(<path1 points>)`` and, with reciprocity,
        the matching ratio for the swapped query).
    """
    # NOTE(review): torch.load unpickles its input -- only call this on
    # trusted files.
    source = torch.load(path1)
    target = torch.load(path2)

    def _matches(query_pos, support_pos):
        # Radius search (mode=1), then the project's pair filter.
        found, found_dist = ball_query(query_pos,
                                       support_pos,
                                       radius=max_distance_overlap,
                                       max_num=num_pos,
                                       mode=1)
        return filter_pair(found, found_dist)

    pair = _matches(target.pos, source.pos)
    overlap = [pair.shape[0] / len(source.pos)]
    if reciprocity:
        reverse_pair = _matches(source.pos, target.pos)
        overlap.append(reverse_pair.shape[0] / len(target.pos))

    return {
        "pair": pair,
        "path_source": path1,
        "path_target": path2,
        "overlap": overlap,
    }
Ejemplo n.º 2
0
def compute_overlap_and_matches(data1,
                                data2,
                                max_distance_overlap,
                                reciprocity=False,
                                num_pos=1,
                                rot_gt=None):
    """Match the points of two clouds within a radius and measure overlap.

    ``data1.pos`` is rotated by ``rot_gt`` (as ``pos @ rot_gt.T``) before
    matching, and both point sets are cast to ``torch.float``.

    Args:
        data1: first cloud; must expose ``pos``.
        data2: second cloud; must expose ``pos``.
        max_distance_overlap: search radius forwarded to ``ball_query``.
        reciprocity: when True, also run the query with the roles swapped.
        num_pos: maximum number of neighbours per query (``max_num``).
        rot_gt: rotation matrix applied to ``data1.pos``. ``None`` (the
            default) means ``torch.eye(3)``. It is cast to ``torch.float``
            and moved to the device of ``data1.pos``.

    Returns:
        dict with:
          * ``pair``: filtered result of
            ``ball_query(data2.pos, rotated data1.pos)``;
          * ``pair2``: filtered result of the swapped query, or an empty
            list when ``reciprocity`` is False;
          * ``overlap``: ``[len(pair) / len(data1.pos)]``, extended with
            ``len(pair2) / len(data2.pos)`` when ``reciprocity`` is True.
    """
    # Build the identity per call instead of sharing one tensor created at
    # definition time as a default argument.
    if rot_gt is None:
        rot_gt = torch.eye(3)

    pos2 = data2.pos.to(torch.float)
    pos1 = data1.pos.to(torch.float)
    # Match dtype/device so a float64 or differently-placed rotation does
    # not break the matmul; rotate once and reuse for both directions.
    rot = rot_gt.to(device=pos1.device, dtype=torch.float)
    pos1 = pos1 @ rot.T

    pair, dist = ball_query(pos2,
                            pos1,
                            radius=max_distance_overlap,
                            max_num=num_pos,
                            mode=1,
                            sorted=True)
    pair = filter_pair(pair, dist)
    pair2 = []
    overlap = [pair.shape[0] / len(data1.pos)]
    if reciprocity:
        pair2, dist2 = ball_query(pos1,
                                  pos2,
                                  radius=max_distance_overlap,
                                  max_num=num_pos,
                                  mode=1,
                                  sorted=True)
        pair2 = filter_pair(pair2, dist2)
        overlap.append(pair2.shape[0] / len(data2.pos))

    return dict(pair=pair, pair2=pair2, overlap=overlap)
Ejemplo n.º 3
0
    def __call__(self, data):
        """Keep only points with more than ``self.min_num`` close neighbours.

        The cloud is queried against itself with a dense (``mode=0``)
        ``ball_query`` of radius ``self.radius_nn``. Per point, only entries
        with a strictly positive distance are counted, so zero-distance
        matches (presumably each point's match with itself) do not count.
        Points whose count exceeds ``self.min_num`` are kept via
        ``apply_mask``, which also receives ``self.skip_keys``.
        """
        _, neighbour_dist = ball_query(data.pos, data.pos,
                                       radius=self.radius_nn,
                                       max_num=-1, mode=0)
        neighbour_count = (neighbour_dist > 0).sum(dim=1)
        keep = neighbour_count > self.min_num
        return apply_mask(data, keep, self.skip_keys)
Ejemplo n.º 4
0
    def __call__(self, data):
        """Remove every point within ``self.radius`` of a set of centres.

        Centres are ``self.centers`` when ``self.name_ind`` is None;
        otherwise they are the points of ``data.pos`` selected by the
        indices stored in ``data[self.name_ind]``. Points returned by
        ``ball_query`` for any centre are dropped via ``apply_mask``.
        """
        if self.name_ind is None:
            center = self.centers
        else:
            center = data.pos[data[self.name_ind].long()]
        ind, dist = ball_query(data.pos, center,
                               radius=self.radius,
                               max_num=-1, mode=1)
        # ">= 0" also removes points lying exactly on a centre (distance 0),
        # e.g. the indexed centre points themselves; "> 0" would keep them.
        # NOTE(review): assumes any invalid entry would carry a negative
        # distance -- confirm against the ball_query implementation.
        ind = ind[dist[:, 0] >= 0]
        mask = torch.ones(len(data.pos), dtype=torch.bool)
        # Column 0 is used as an index into data.pos.
        mask[ind[:, 0]] = False
        return apply_mask(data, mask)
Ejemplo n.º 5
0
 def __call__(self, data):
     """Crop a ball of radius ``self.radius`` around one random point.

     Every attribute of ``data`` whose length equals the number of points is
     restricted to the points returned by ``ball_query``; other attributes
     are left as they are. ``data`` is updated in place and returned.
     """
     i = torch.randint(0, len(data.pos), (1,))
     ind, dist = ball_query(data.pos, data.pos[i].view(1, 3), radius=self.radius, max_num=-1, mode=1)
     # ">= 0" keeps the sampled centre itself (distance 0) in the crop;
     # "> 0" silently dropped it (and any exact duplicate of it).
     # NOTE(review): assumes any invalid entry would carry a negative
     # distance -- confirm against the ball_query implementation.
     ind = ind[dist[:, 0] >= 0]
     size_pos = len(data.pos)
     for k in data.keys:
         if size_pos == len(data[k]):
             data[k] = data[k][ind[:, 0]]
     return data
Ejemplo n.º 6
0
    def __call__(self, data):
        """Remove ``self.num_sphere`` balls of radius ``self.radius``.

        Ball centres are points of ``data.pos`` drawn independently at
        random (duplicates are possible). Every point returned by
        ``ball_query`` for any centre is dropped via ``apply_mask``.
        """
        pos = data.pos
        list_ind = torch.randint(0, len(pos), (self.num_sphere,))

        ind, dist = ball_query(pos, pos[list_ind], radius=self.radius, max_num=-1, mode=1)
        # ">= 0" also removes the sampled centres themselves (distance 0);
        # with "> 0" one point would remain in the middle of each hole.
        # NOTE(review): assumes any invalid entry would carry a negative
        # distance -- confirm against the ball_query implementation.
        ind = ind[dist[:, 0] >= 0]
        mask = torch.ones(len(pos), dtype=torch.bool)
        mask[ind[:, 0]] = False
        return apply_mask(data, mask)
Ejemplo n.º 7
0
    def unsupervised_preprocess(self, data_source_o, data_target_o):
        """
        Build a source/target pair with matches for self-supervised learning.

        Each attempt does the following:

        1. When ``self.ss_transform`` is set, apply one randomly chosen
           transform from ``self.ss_transform.transforms`` to a clone of
           each input. Otherwise use the inputs as they are.
        2. Pick a random source point and a random block radius in
           ``[self.min_size_block, self.max_size_block)``.
        3. Collect the source points inside that ball.
        4. Match those points to the target within
           ``self.max_dist_overlap``.

        Attempts repeat until at least ``self.min_points`` matches are
        found. There is no retry limit. At least one attempt is always
        made, so a non-positive ``min_points`` returns after the first one.

        Returns:
            (data_source, data_target, new_pair), where ``new_pair`` stacks
            source indices and target indices as its two columns.
        """
        while True:
            if self.ss_transform is not None:
                transforms = self.ss_transform.transforms
                t1 = transforms[np.random.randint(0, len(transforms))]
                t2 = transforms[np.random.randint(0, len(transforms))]
                data_source = t1(data_source_o.clone())
                data_target = t2(data_target_o.clone())
            else:
                data_source = data_source_o
                data_target = data_target_o

            pos = data_source.pos
            i = torch.randint(0, len(pos), (1, ))
            size_block = random.random() * (
                self.max_size_block -
                self.min_size_block) + self.min_size_block
            point = pos[i].view(1, 3)
            ind, dist = ball_query(point,
                                   pos,
                                   radius=size_block,
                                   max_num=-1,
                                   mode=1)
            # ">= 0" keeps zero-distance results: the chosen centre point,
            # and exact source/target correspondences. With "> 0",
            # identical clouds yielded no matches and looped forever.
            # NOTE(review): assumes any invalid entry would carry a
            # negative distance -- confirm against ball_query.
            _, col = ind[dist[:, 0] >= 0].t()
            ind_t, dist_t = ball_query(data_target.pos,
                                       pos[col],
                                       radius=self.max_dist_overlap,
                                       max_num=1,
                                       mode=1)
            col_target, ind_col = ind_t[dist_t[:, 0] >= 0].t()
            # ind_col indexes into pos[col]; map back to source indices.
            col = col[ind_col]
            new_pair = torch.stack((col, col_target)).T
            # do-while: the original "while len_col < min_points" never ran
            # for min_points <= 0 and then returned unbound names.
            if len(new_pair) >= self.min_points:
                return data_source, data_target, new_pair
    def __call__(self, data: Data, ind):
        """Extract the patch of radius ``self.radius_patch`` around one point.

        Args:
            data: point cloud whose ``pos`` holds the coordinates.
            ind: index of the patch centre in ``data.pos``.

        Returns:
            A new ``Data`` holding, for every tensor attribute of ``data``
            that has at least one dimension and whose first dimension can be
            indexed by all patch indices, the rows inside the patch.
        """
        pos = data.pos
        point = pos[ind].view(1, 3)
        # Use a distinct name so the ``ind`` argument is not shadowed.
        neigh, dist = ball_query(point, pos, radius=self.radius_patch, max_num=-1, mode=1)

        # ">= 0" keeps the centre point itself (distance 0) in its patch.
        # NOTE(review): assumes any invalid entry would carry a negative
        # distance -- confirm against the ball_query implementation.
        _, col = neigh[dist[:, 0] >= 0].t()
        patch = Data()
        for key in data.keys:
            value = data[key]
            # 0-dim tensors have no shape[0]; skip them instead of raising.
            if (torch.is_tensor(value) and value.dim() > 0
                    and torch.all(col < value.shape[0])):
                patch[key] = value[col]

        return patch
Ejemplo n.º 9
0
    def __call__(self, data):
        """Carve ``self.num_sphere`` holes of radius ``self.radius``.

        Hole centres are drawn independently at random from the points of a
        grid-subsampled copy of the cloud (``self.grid_sampling`` applied to
        ``data.clone()``). Every original point returned by ``ball_query``
        for any centre is dropped via ``apply_mask``.
        """
        sampled = self.grid_sampling(data.clone())
        picks = torch.randint(0, len(sampled.pos), (self.num_sphere,))
        centres = sampled.pos[picks]

        neigh, dist = ball_query(data.pos, centres, radius=self.radius, max_num=-1, mode=1)
        inside = neigh[dist[:, 0] >= 0][:, 0]

        keep = torch.ones(len(data.pos), dtype=torch.bool)
        keep[inside] = False
        return apply_mask(data, keep)
Ejemplo n.º 10
0
    def __call__(self, data):
        """Filter points with ``rw_mask`` over a radius-neighbour graph.

        A dense (``mode=0``) ``ball_query`` of the cloud against itself, with
        radius ``self.radius`` and at most ``self.max_num`` neighbours, is
        converted to NumPy. It is handed to ``rw_mask`` together with an
        all-True starting mask, ``self.num_iter`` and ``self.dropout_ratio``.
        The returned mask is applied with ``apply_mask``, which also receives
        ``self.skip_keys``.
        NOTE(review): presumably rw_mask performs a random-walk based point
        dropout -- confirm against its definition.
        """
        coords = data.pos.detach().cpu().numpy()
        neigh, dist = ball_query(data.pos, data.pos,
                                 radius=self.radius,
                                 max_num=self.max_num, mode=0)

        def _to_numpy(t):
            return t.detach().cpu().numpy()

        start_mask = np.ones(len(coords), dtype=bool)
        keep = rw_mask(coords,
                       _to_numpy(neigh),
                       _to_numpy(dist),
                       start_mask,
                       num_iter=self.num_iter,
                       random_ratio=self.dropout_ratio)
        return apply_mask(data, keep, self.skip_keys)