def test_multiscaleTransforms(self):
    """Check the multiscale and upsample data produced by MultiScaleTransform.

    A flat N x N grid of points (z = 0) is passed through a three-level
    transform. Level 0 uses GridSampling(0.25), level 1 has no sampler and
    level 2 uses GridSampling(0.5). The test then verifies the sampled
    positions, the neighbour indices of each level and the indices/batching
    increments of the upsampling data.
    """
    samplers = [GridSampling(0.25), None, GridSampling(0.5)]
    search = [
        RadiusNeighbourFinder(0.5, 100, ConvolutionFormat.PARTIAL_DENSE.value),
        RadiusNeighbourFinder(0.5, 150, ConvolutionFormat.PARTIAL_DENSE.value),
        RadiusNeighbourFinder(1, 200, ConvolutionFormat.PARTIAL_DENSE.value),
    ]
    upsampler = [KNNInterpolate(1), KNNInterpolate(1)]

    N = 10
    grid = np.linspace(0, 1, N)
    xv, yv = np.meshgrid(grid, grid)
    # Stack into a single (N*N, 3) float64 array before converting. Calling
    # torch.tensor() on a *list* of ndarrays is slow and warns on recent torch.
    pos = torch.tensor(np.stack([xv.flatten(), yv.flatten(), np.zeros(N * N)], axis=1))
    x = torch.ones_like(pos)
    d = Data(pos=pos, x=x).contiguous()

    ms_transform = MultiScaleTransform(
        {"sampler": samplers, "neighbour_finder": search, "upsample_op": upsampler}
    )
    transformed = ms_transform(d.clone())

    # The top-level object is expected to keep the full-resolution features
    # and positions.
    npt.assert_almost_equal(transformed.x.numpy(), x.numpy())
    npt.assert_almost_equal(transformed.pos.numpy(), pos.numpy())

    ms = transformed.multiscale
    # samplers[1] is None, so level 1 is expected to reuse level 0's points.
    npt.assert_almost_equal(ms[0].pos.numpy(), ms[1].pos.numpy())
    npt.assert_almost_equal(ms[0].pos.numpy(), samplers[0](d.clone()).pos.numpy())
    # Level 2 samples from level 0's output, not from the original cloud.
    npt.assert_almost_equal(ms[2].pos.numpy(), samplers[2](ms[0].clone()).pos.numpy())

    # Level-0 neighbour indices refer to the input cloud, so the batching
    # increment is expected to equal the input point count.
    self.assertEqual(ms[0].__inc__("idx_neighboors", 0), pos.shape[0])

    # Recompute level-0 neighbours directly (single batch) and compare.
    idx = search[0](
        d.pos,
        ms[0].pos,
        torch.zeros((d.pos.shape[0]), dtype=torch.long),
        torch.zeros((ms[0].pos.shape[0]), dtype=torch.long),
    )
    neighbours = ms[0].idx_neighboors
    # Check row counts first so that extra or missing rows are not silently
    # ignored by the row-wise comparison below.
    self.assertEqual(len(neighbours), len(idx))
    # Compare each row as a set so the check does not depend on the order of
    # neighbours within a row.
    for row, expected_row in zip(neighbours, idx):
        self.assertEqual(set(row.tolist()), set(expected_row.tolist()))

    # Dense neighbour tensors are padded to each finder's max neighbour count.
    self.assertEqual(ms[1].idx_neighboors.shape[1], 150)
    self.assertEqual(ms[2].idx_neighboors.shape[1], 200)

    upsample = transformed.upsample
    self.assertEqual(upsample[0].num_nodes, ms[1].num_nodes)
    self.assertEqual(upsample[1].num_nodes, pos.shape[0])
    # Final upsampling maps level-0 points (x_idx) back onto the input cloud
    # (y_idx); the increments must match the sizes of those two point sets.
    self.assertEqual(upsample[1].x_idx.max(), ms[0].num_nodes - 1)
    self.assertEqual(upsample[1].y_idx.max(), pos.shape[0] - 1)
    self.assertEqual(upsample[1].__inc__("x_idx", 0), ms[0].num_nodes)
    self.assertEqual(upsample[1].__inc__("y_idx", 0), pos.shape[0])
def test_precompute(self):
    """KNN precomputation should produce one entry per support point.

    Support points 3 at (0, 0, 0) and 4 at (0.1, 0, 0) are expected to
    resolve to the same nearest query point after grid sampling with cell
    size 1, so their x_idx entries must match.
    """
    support = Data(
        x=torch.tensor([0, 0, 0, 0, 1]).view(-1, 1),
        pos=torch.tensor([[1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 0], [0.1, 0, 0]]),
    )
    query = GridSampling3D(1)(support.clone())

    knn = KNNInterpolate(1)
    cached = knn.precompute(query, support)

    self.assertEqual(cached.num_nodes, 5)
    self.assertEqual(cached.x_idx[4], cached.x_idx[3])
def test_compute(self):
    """Interpolating with precomputed neighbours must match direct interpolation.

    Random support and query clouds are built. The interpolation result
    computed on the fly is then compared with the one that reuses the output
    of ``precompute``.
    """
    n_support = 100
    n_query = n_support // 2

    # Keep the RNG draw order: support pos, support x, query x, query pos.
    support_pos = torch.randn((n_support, 3))
    support_feat = torch.randn((n_support, 4))
    support = Data(x=support_feat, pos=support_pos)

    query_feat = torch.randn((n_query, 4))
    query_pos = torch.randn((n_query, 3))
    query = Data(x=query_feat, pos=query_pos)

    knn = KNNInterpolate(3)
    cached = knn.precompute(query, support)
    direct = knn(query, support)
    from_cache = knn(query, support, precomputed=cached)

    torch.testing.assert_allclose(direct, from_cache)