Example #1
0
def test_fps_start_idx():
    """Run farthest point sampling with an explicit ``start_idx=0``.

    Builds a random batch of 5 clouds x 200 points x 3 coordinates, samples
    10 indices per cloud, and asserts that at least one cloud reports 0 as
    its first sampled index.
    """
    num_points, num_clouds, num_samples = 1000, 5, 10
    cloud_shape = (num_clouds, num_points // num_clouds, 3)
    points = th.tensor(np.random.uniform(size=cloud_shape))
    device = F.ctx()
    if F.gpu_ctx():
        # F.gpu_ctx() gates the device move; presumably true when the test
        # backend targets a GPU.
        points = points.to(device)
    picked = farthest_point_sampler(points, num_samples, start_idx=0)
    first_picked = picked[:, 0]
    assert th.any(first_picked == 0)
Example #2
0
def test_fps():
    """Sanity-check farthest point sampling on a random torch batch.

    Samples 10 indices from each of 5 clouds of 200 random 3-D points.
    Checks the result's leading dimensions and that not every index is zero.
    """
    total_points = 1000
    n_clouds = 5
    n_samples = 10
    raw = np.random.uniform(size=(n_clouds, total_points // n_clouds, 3))
    cloud = th.tensor(raw)
    device = F.ctx()
    if F.gpu_ctx():
        cloud = cloud.to(device)
    indices = farthest_point_sampler(cloud, n_samples)
    n_rows, n_cols = indices.shape[0], indices.shape[1]
    assert n_rows == n_clouds
    assert n_cols == n_samples
    # A result of all zeros would mean the sampler never moved off index 0.
    assert indices.sum() > 0
Example #3
0
    def forward(self, pos, feat):
        """Sample centroids, run message passing around them, and gather results.

        Picks ``self.n_points`` centroids per cloud with
        ``farthest_point_sampler``. Builds a graph around them with
        ``self.frnn_graph`` (presumably a fixed-radius neighbor graph) and
        runs ``update_all`` with ``self.message`` / ``self.conv``. Finally it
        reads back the ``'pos'`` and ``'new_feat'`` data of the nodes flagged
        as centers.

        Args:
            pos: Point positions. ``farthest_point_sampler`` expects shape
                ``(B, N, C)``, so ``pos.shape[0]`` is the batch size.
            feat: Per-point features, passed through to ``self.frnn_graph``.

        Returns:
            ``(pos_res, feat_res)``: center-node positions shaped
            ``(B, -1, pos_dim)`` and center-node features shaped
            ``(B, -1, feat_dim)``. The middle dimension is the number of
            center nodes per cloud, expected to be ``self.n_points``.
        """
        centroids = farthest_point_sampler(pos, self.n_points)
        g = self.frnn_graph(pos, centroids, feat)
        g.update_all(self.message, self.conv)

        # Use the batch size of the actual input rather than the configured
        # ``self.batch_size``. For a smaller (e.g. final) batch, the fixed
        # value makes ``view`` raise, or silently regroup the centroids into
        # the wrong number of clouds when the sizes happen to divide.
        batch_size = pos.shape[0]

        # Only nodes flagged as centers carry the sampled outputs.
        mask = g.ndata['center'] == 1
        pos_dim = g.ndata['pos'].shape[-1]
        feat_dim = g.ndata['new_feat'].shape[-1]
        pos_res = g.ndata['pos'][mask].view(batch_size, -1, pos_dim)
        feat_res = g.ndata['new_feat'][mask].view(batch_size, -1, feat_dim)
        return pos_res, feat_res
Example #4
0
def test_fps():
    """Sanity-check farthest point sampling on a random MXNet batch.

    Samples 10 indices from each of 5 clouds of 200 random 3-D points.
    Checks the result's leading dimensions and that not every index is zero.
    """
    total_points = 1000
    n_clouds = 5
    n_samples = 10
    raw = np.random.uniform(size=(n_clouds, total_points // n_clouds, 3))
    cloud = mx.nd.array(raw)
    device = F.ctx()
    if F.gpu_ctx():
        cloud = cloud.as_in_context(device)
    indices = farthest_point_sampler(cloud, n_samples)
    n_rows, n_cols = indices.shape[0], indices.shape[1]
    assert n_rows == n_clouds
    assert n_cols == n_samples
    # A result of all zeros would mean the sampler never moved off index 0.
    assert indices.sum() > 0
Example #5
0
    def forward(self, pos, feat):
        """Sample centroids once, aggregate over several graph groups, concat.

        Picks ``self.npoints`` centroids per cloud with
        ``farthest_point_sampler``. Then, for each of the ``self.group_size``
        groups, it builds a graph with ``self.frnn_graph_list[i]`` and runs
        ``update_all`` with that group's message/conv functions. It collects
        the ``'new_feat'`` data of the center nodes from each group.

        Args:
            pos: Point positions. ``farthest_point_sampler`` expects shape
                ``(B, N, C)``, so ``pos.shape[0]`` is the batch size.
            feat: Per-point features, passed through to every group's graph
                builder.

        Returns:
            ``(pos_res, feat_res)``. ``pos_res`` holds the center-node
            positions from the first group, shaped ``(B, -1, pos_dim)``.
            ``feat_res`` is every group's center-node features concatenated
            along dim 2.

        NOTE(review): with ``self.group_size == 0`` the loop never runs, so
        ``torch.cat`` receives an empty list and raises.
        """
        centroids = farthest_point_sampler(pos, self.npoints)
        # Use the batch size of the actual input rather than the configured
        # ``self.batch_size``. For a smaller (e.g. final) batch, the fixed
        # value makes ``view`` raise, or silently regroup the centroids into
        # the wrong number of clouds when the sizes happen to divide.
        batch_size = pos.shape[0]
        feat_res_list = []

        for i in range(self.group_size):
            g = self.frnn_graph_list[i](pos, centroids, feat)
            g.update_all(self.message_list[i], self.conv_list[i])
            # Only nodes flagged as centers carry the sampled outputs.
            mask = g.ndata['center'] == 1
            if i == 0:
                # Every group is built from the same ``pos`` and
                # ``centroids``, so positions are read from the first only.
                pos_dim = g.ndata['pos'].shape[-1]
                pos_res = g.ndata['pos'][mask].view(batch_size, -1, pos_dim)
            feat_dim = g.ndata['new_feat'].shape[-1]
            feat_res = g.ndata['new_feat'][mask].view(batch_size, -1,
                                                      feat_dim)
            feat_res_list.append(feat_res)

        feat_res = torch.cat(feat_res_list, dim=2)
        return pos_res, feat_res