Example #1
0
    def test_multiscaleTransforms(self):
        """Check MultiScaleTransform's precomputed sampling, neighbour and upsampling data.

        A flat 10x10 grid of unit-feature points goes through a three-scale
        transform: grid sampling 0.25, then no sampling, then grid sampling 0.5.
        Each precomputed level is compared against running the same sampler or
        neighbour finder directly. The original ``x`` and ``pos`` must come out
        unchanged.
        """
        samplers = [GridSampling(0.25), None, GridSampling(0.5)]
        search = [
            RadiusNeighbourFinder(0.5, 100,
                                  ConvolutionFormat.PARTIAL_DENSE.value),
            RadiusNeighbourFinder(0.5, 150,
                                  ConvolutionFormat.PARTIAL_DENSE.value),
            RadiusNeighbourFinder(1, 200,
                                  ConvolutionFormat.PARTIAL_DENSE.value),
        ]
        upsampler = [KNNInterpolate(1), KNNInterpolate(1)]

        # Regular N x N grid on the z = 0 plane. Stacking the numpy columns
        # first avoids torch's slow list-of-ndarrays conversion; the result is
        # the same float64 (N * N, 3) tensor as transposing a 3-row tensor.
        N = 10
        x = np.linspace(0, 1, N)
        y = np.linspace(0, 1, N)
        xv, yv = np.meshgrid(x, y)

        pos = torch.tensor(
            np.stack([xv.flatten(), yv.flatten(), np.zeros(N * N)], axis=1))
        x = torch.ones_like(pos)
        d = Data(pos=pos, x=x).contiguous()
        ms_transform = MultiScaleTransform({
            "sampler": samplers,
            "neighbour_finder": search,
            "upsample_op": upsampler
        })

        # The transform must not alter the full-resolution features/positions.
        transformed = ms_transform(d.clone())
        npt.assert_almost_equal(transformed.x.numpy(), x.numpy())
        npt.assert_almost_equal(transformed.pos.numpy(), pos.numpy())

        # Scale 1 has no sampler, so it shares scale 0's positions; scale 2 is
        # sampled from scale 0's output.
        ms = transformed.multiscale
        npt.assert_almost_equal(ms[0].pos.numpy(), ms[1].pos.numpy())
        npt.assert_almost_equal(ms[0].pos.numpy(),
                                samplers[0](d.clone()).pos.numpy())
        npt.assert_almost_equal(ms[2].pos.numpy(),
                                samplers[2](ms[0].clone()).pos.numpy())

        # The test expects the batching increment for neighbour indices at
        # scale 0 to equal the number of full-resolution (support) points.
        self.assertEqual(ms[0].__inc__("idx_neighboors", 0), pos.shape[0])
        idx = search[0](
            d.pos,
            ms[0].pos,
            torch.zeros((d.pos.shape[0]), dtype=torch.long),
            torch.zeros((ms[0].pos.shape[0]), dtype=torch.long),
        )
        # Compare row-by-row as sets (neighbour order is not significant).
        # Check the row counts first so that zip cannot hide missing rows.
        self.assertEqual(len(ms[0].idx_neighboors), len(idx))
        for actual_row, expected_row in zip(ms[0].idx_neighboors, idx):
            self.assertEqual(set(actual_row.tolist()),
                             set(expected_row.tolist()))
        # Neighbour matrices are sized by each finder's max-neighbour argument.
        self.assertEqual(ms[1].idx_neighboors.shape[1], 150)
        self.assertEqual(ms[2].idx_neighboors.shape[1], 200)

        # upsample[0] targets scale 1; upsample[1] maps scale 0 back to the
        # full-resolution cloud.
        upsample = transformed.upsample
        self.assertEqual(upsample[0].num_nodes, ms[1].num_nodes)
        self.assertEqual(upsample[1].num_nodes, pos.shape[0])
        self.assertEqual(upsample[1].x_idx.max(), ms[0].num_nodes - 1)
        self.assertEqual(upsample[1].y_idx.max(), pos.shape[0] - 1)
        self.assertEqual(upsample[1].__inc__("x_idx", 0), ms[0].num_nodes)
        self.assertEqual(upsample[1].__inc__("y_idx", 0), pos.shape[0])
    def test_precompute(self):
        """KNNInterpolate.precompute yields one entry per support point.

        Five support points are grid-subsampled with cell size 1 to form the
        query. With k=1, the support points [0, 0, 0] (index 3) and
        [0.1, 0, 0] (index 4) are expected to share the same nearest query
        index. Presumably they land in the same grid cell; confirm against the
        GridSampling3D rounding rule.
        """
        coords = torch.tensor(
            [[1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 0], [0.1, 0, 0]]
        )
        feats = torch.tensor([0, 0, 0, 0, 1]).unsqueeze(-1)
        support = Data(x=feats, pos=coords)

        query = GridSampling3D(1)(support.clone())
        up = KNNInterpolate(1).precompute(query, support)

        self.assertEqual(up.num_nodes, 5)
        self.assertEqual(up.x_idx[4], up.x_idx[3])
    def test_compute(self):
        """Interpolating with precomputed neighbours must match on-the-fly interpolation.

        The random support and query clouds come from a locally seeded
        generator. Failures are therefore reproducible, and the global RNG
        state is left untouched.
        """
        generator = torch.Generator().manual_seed(0)
        npoints = 100
        nquery = npoints // 2
        pos = torch.randn((npoints, 3), generator=generator)
        x = torch.randn((npoints, 4), generator=generator)
        support = Data(x=x, pos=pos)
        query = Data(
            x=torch.randn((nquery, 4), generator=generator),
            pos=torch.randn((nquery, 3), generator=generator),
        )

        interpolate = KNNInterpolate(3)
        precomputed = interpolate.precompute(query, support)

        gt = interpolate(query, support)
        pre = interpolate(query, support, precomputed=precomputed)

        # torch.testing.assert_allclose is deprecated (torch >= 1.12).
        # Use assert_close when the installed torch provides it.
        assert_close = (getattr(torch.testing, "assert_close", None)
                        or torch.testing.assert_allclose)
        assert_close(gt, pre)