Beispiel #1
0
    def test_raises(self):
        """Inconsistent configurations and bad scale indices raise ValueError."""
        # One radius but two neighbour caps: construction is expected to fail.
        # The instance is never used, so it is not bound to a name.
        with self.assertRaises(ValueError):
            MultiscaleRadiusNeighbourFinder([1], [3, 4])

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        nei_finder = MultiscaleRadiusNeighbourFinder([1, 2], [3, 4])
        # Only two scales are configured, so requesting scale index 10 is
        # expected to be rejected.
        with self.assertRaises(ValueError):
            nei_finder(x, y, batch_x, batch_y, 10)
Beispiel #2
0
    def test_single_search(self):
        """A scalar radius / scalar cap finder returns 4 entries at scale 0."""
        points = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        points_batch = torch.tensor([0, 0, 0, 0])
        queries = torch.Tensor([[-1, 0], [1, 0]])
        queries_batch = torch.tensor([0, 0])

        finder = MultiscaleRadiusNeighbourFinder(1, 4)
        found = finder(points, queries, points_batch, queries_batch, 0)
        # Row 1 of the result is checked for length; presumably the result
        # is an index tensor with one column per neighbour pair -- confirm.
        self.assertEqual(found[1, :].shape[0], 4)
Beispiel #3
0
    def test_multi_num_search(self):
        """A shared radius with per-scale neighbour caps gives one result per scale."""
        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        nei_finder = MultiscaleRadiusNeighbourFinder(10, [3, 4])
        # One search per configured scale (two neighbour caps -> scales 0, 1).
        multiscale = [
            nei_finder(x, y, batch_x, batch_y, scale) for scale in range(2)
        ]

        self.assertEqual(len(multiscale), 2)
        # Radius 10 presumably reaches all four x points from both y points,
        # so each count is 2 queries * cap: 2 * 3 = 6 and 2 * 4 = 8.
        self.assertEqual(multiscale[0][1, :].shape[0], 6)
        self.assertEqual(multiscale[1][1, :].shape[0], 8)
Beispiel #4
0
    def __init__(self,
                 ratio=None,
                 radius=None,
                 radius_num_point=None,
                 down_conv_nn=None,
                 *args,
                 **kwargs):
        """Build the parent with a sampler and neighbour finder, then a PointConv.

        Args:
            ratio: forwarded to ``FPSSampler(ratio=...)`` and stored as
                ``self._ratio``.
            radius: forwarded as the first argument of
                ``MultiscaleRadiusNeighbourFinder`` and stored as
                ``self._radius``.
            radius_num_point: forwarded as ``max_num_neighbors`` and stored
                as ``self._num_points``.
            down_conv_nn: when not None, passed to ``MLP`` to build the
                ``local_nn`` of the PointConv; otherwise ``local_nn`` is None.
            *args, **kwargs: forwarded unchanged to the parent constructor.
        """
        sampler = FPSSampler(ratio=ratio)
        neighbour_finder = MultiscaleRadiusNeighbourFinder(
            radius, max_num_neighbors=radius_num_point)
        super(SAModule, self).__init__(sampler, neighbour_finder, *args, **kwargs)

        if down_conv_nn is None:
            local_nn = None
        else:
            local_nn = MLP(down_conv_nn)

        self._conv = PointConv(local_nn=local_nn, global_nn=None)
        # Keep the raw configuration on the instance.
        self._radius = radius
        self._ratio = ratio
        self._num_points = radius_num_point