def test_multiscaleTransforms(self):
    """Check MultiScaleTransform on a flat 10 x 10 grid of points.

    Three scales are configured (grid sampling, no sampling, coarser grid
    sampling), each with its own radius neighbour finder, plus two
    KNN-interpolation upsamplers. The test verifies that the input data is
    left untouched, that every scale carries the expected positions and
    neighbour indices, and that the upsample entries index into the
    expected point sets.
    """
    samplers = [GridSampling(0.25), None, GridSampling(0.5)]
    conv_format = ConvolutionFormat.PARTIAL_DENSE.value
    search = [
        RadiusNeighbourFinder(radius, max_neighbours, conv_format)
        for radius, max_neighbours in ((0.5, 100), (0.5, 150), (1, 200))
    ]
    upsampler = [KNNInterpolate(1), KNNInterpolate(1)]

    # Regular grid on the unit square, lying in the z = 0 plane.
    side = 10
    axis = np.linspace(0, 1, side)
    grid_x, grid_y = np.meshgrid(axis, axis)
    coords = [grid_x.flatten(), grid_y.flatten(), np.zeros(side * side)]
    pos = torch.tensor(coords).T
    features = torch.ones_like(pos)
    data = Data(pos=pos, x=features).contiguous()

    strategies = {
        "sampler": samplers,
        "neighbour_finder": search,
        "upsample_op": upsampler,
    }
    ms_transform = MultiScaleTransform(strategies)
    transformed = ms_transform(data.clone())

    # The top-level features and positions must come back unchanged.
    for actual, expected in ((transformed.x, features), (transformed.pos, pos)):
        npt.assert_almost_equal(actual.numpy(), expected.numpy())

    scales = transformed.multiscale
    first, second, third = scales[0], scales[1], scales[2]

    # The second scale has a None sampler, so it is expected to share the
    # first scale's positions.
    npt.assert_almost_equal(first.pos.numpy(), second.pos.numpy())
    npt.assert_almost_equal(
        first.pos.numpy(), samplers[0](data.clone()).pos.numpy()
    )
    npt.assert_almost_equal(
        third.pos.numpy(), samplers[2](first.clone()).pos.numpy()
    )

    num_points = pos.shape[0]
    self.assertEqual(first.__inc__("idx_neighboors", 0), num_points)

    # Query the first finder directly and compare neighbour sets row by row
    # (as sets, so ordering within a row does not matter).
    support_batch = torch.zeros(data.pos.shape[0], dtype=torch.long)
    query_batch = torch.zeros(first.pos.shape[0], dtype=torch.long)
    idx, _ = search[0](data.pos, first.pos, support_batch, query_batch)
    for row, found in enumerate(first.idx_neighboors):
        self.assertEqual(set(found.tolist()), set(idx[row].tolist()))

    for scale, max_neighbours in ((second, 150), (third, 200)):
        self.assertEqual(scale.idx_neighboors.shape[1], max_neighbours)

    upsample = transformed.upsample
    coarse_up, fine_up = upsample[0], upsample[1]
    self.assertEqual(coarse_up.num_nodes, second.num_nodes)
    self.assertEqual(fine_up.num_nodes, num_points)
    self.assertEqual(fine_up.x_idx.max(), first.num_nodes - 1)
    self.assertEqual(fine_up.y_idx.max(), num_points - 1)
    self.assertEqual(fine_up.__inc__("x_idx", 0), first.num_nodes)
    self.assertEqual(fine_up.__inc__("y_idx", 0), num_points)
def test_multiscaleTransforms(self):
    """Check the multiscale output of MultiScaleTransform on a 10 x 10 grid.

    Three scales are configured (grid sampling, no sampling, coarser grid
    sampling), each with its own radius neighbour finder. The test verifies
    that the input data is left untouched and that every scale carries the
    expected positions and neighbour indices.
    """
    fine_sampler = GridSampling(0.25)
    coarse_sampler = GridSampling(0.5)
    samplers = [fine_sampler, None, coarse_sampler]

    search = []
    for radius, max_neighbours in ((0.5, 10), (0.5, 15), (1, 20)):
        conv_type = ConvolutionFormat.PARTIAL_DENSE.value[-1]
        search.append(RadiusNeighbourFinder(radius, max_neighbours, conv_type))

    # Regular grid on the unit square, lying in the z = 0 plane.
    points_per_side = 10
    ticks = np.linspace(0, 1, points_per_side)
    mesh_x, mesh_y = np.meshgrid(ticks, ticks)
    zeros = np.zeros(points_per_side * points_per_side)
    pos = torch.tensor([mesh_x.flatten(), mesh_y.flatten(), zeros]).T
    ones = torch.ones_like(pos)
    sample = Data(pos=pos, x=ones).contiguous()

    ms_transform = MultiScaleTransform(
        {"sampler": samplers, "neighbour_finder": search}
    )
    transformed = ms_transform(sample.clone())

    # The top-level features and positions must come back unchanged.
    npt.assert_almost_equal(transformed.x.numpy(), ones.numpy())
    npt.assert_almost_equal(transformed.pos.numpy(), pos.numpy())

    levels = transformed.multiscale
    level0, level1, level2 = levels[0], levels[1], levels[2]

    # The second scale has a None sampler, so it is expected to share the
    # first scale's positions.
    npt.assert_almost_equal(level0.pos.numpy(), level1.pos.numpy())
    expected_level0 = samplers[0](sample.clone()).pos.numpy()
    npt.assert_almost_equal(level0.pos.numpy(), expected_level0)
    expected_level2 = samplers[2](level0.clone()).pos.numpy()
    npt.assert_almost_equal(level2.pos.numpy(), expected_level2)

    self.assertEqual(level0.__inc__("idx_neighboors", 0), pos.shape[0])

    # Querying the first finder directly must reproduce the stored indices.
    idx, _ = search[0](
        sample.pos,
        level0.pos,
        torch.zeros(sample.pos.shape[0], dtype=torch.long),
        torch.zeros(level0.pos.shape[0], dtype=torch.long),
    )
    torch.testing.assert_allclose(level0.idx_neighboors, idx)

    self.assertEqual(level1.idx_neighboors.shape[1], 15)
    self.assertEqual(level2.idx_neighboors.shape[1], 20)
def set_strategies(self, model):
    """Build a MultiScaleTransform from the model's strategies.

    The strategies are obtained from
    ``model.get_sampling_and_search_strategies()`` and the resulting
    transform is stored on ``self._transform``.
    """
    self._transform = MultiScaleTransform(
        model.get_sampling_and_search_strategies()
    )
def set_strategies(self, model):
    """Store a MultiScaleTransform built from ``model.get_spatial_ops()``.

    The transform is kept on ``self._transform``.
    """
    spatial_ops = model.get_spatial_ops()
    self._transform = MultiScaleTransform(spatial_ops)
def set_strategies(self, model, precompute_multi_scale=False):
    """Create a MultiScaleTransform for ``model`` and register it.

    The strategies come from ``model.get_sampling_and_search_strategies()``.
    ``precompute_multi_scale`` is forwarded to ``MultiScaleTransform`` as
    its second positional argument, and the resulting transform is handed
    to ``self._set_multiscale_transform``.
    """
    multiscale = MultiScaleTransform(
        model.get_sampling_and_search_strategies(),
        precompute_multi_scale,
    )
    self._set_multiscale_transform(multiscale)
def set_strategies(self, model):
    """Create a pair-wrapped MultiScaleTransform for ``model`` and register it.

    A ``MultiScaleTransform`` is built from ``model.get_spatial_ops()``,
    wrapped in a ``PairTransform``, and handed to
    ``self._set_multiscale_transform``.
    """
    spatial_ops = model.get_spatial_ops()
    multiscale = MultiScaleTransform(spatial_ops)
    paired = PairTransform(multiscale)
    self._set_multiscale_transform(paired)