def test_point_edt2_3d(self):
        """Check point_edt2 on a batch of single 3-D points.

        Verifies the raster shape for a (50, 100, 100) grid and that the
        raster is ~0 at the voxel holding the point. Assumes make_grid
        places index (25, 50, 50) at coordinate (25., 50., 50.) -- TODO
        confirm against make_grid.
        """
        grid = make_grid(50, 100, 100)
        point = torch.tensor([[25., 50., 50.]])
        # (1, 3) -> (bs, 3): the same point repeated for every batch item.
        point = torch.cat([point] * bs, dim=0)

        raster = point_edt2(point, grid)
        self.assertEqual(raster.shape, (bs, 50, 100, 100))
        # .item() yields a Python float, so assertAlmostEqual's subtraction,
        # abs() and round() operate on plain numbers rather than a 0-d tensor.
        self.assertAlmostEqual(raster[0, 25, 50, 50].item(), 0.0)

    def test_point_edt2_2d_ch(self):
        """Check point_edt2 with two 2-D points per batch item.

        The test expects one output channel per point, giving a
        (bs, 2, 100, 100) raster, and each channel to be ~0 at its own
        point. Assumes make_grid places index (i, j) at coordinate (i., j.)
        -- TODO confirm against make_grid.
        """
        grid = make_grid(100, 100)
        point = torch.tensor([[[50., 50.], [40., 40.]]])
        # (1, 2, 2) -> (bs, 2, 2): the same pair of points for every batch item.
        point = torch.cat([point] * bs, dim=0)

        raster = point_edt2(point, grid)
        self.assertEqual(raster.shape, (bs, 2, 100, 100))
        # Compare Python floats (via .item()) rather than 0-d tensors.
        self.assertAlmostEqual(raster[0, 0, 50, 50].item(), 0.0)
        self.assertAlmostEqual(raster[0, 1, 40, 40].item(), 0.0)
Esempio n. 3
0
def render_points(params, sigma2, grid):
    """Render points as an ``exp``-transformed distance raster.

    The point coordinates ``params`` and ``grid`` are passed to
    ``point_edt2``. The resulting raster and ``sigma2`` are passed to the
    project's two-argument ``exp`` (not ``torch.exp``). Presumably this is
    a Gaussian-style falloff with variance ``sigma2``; confirm against its
    definition. The result gains a new leading dimension of size 1.
    """
    distances = point_edt2(params, grid)
    rendered = exp(distances, sigma2)
    return rendered.unsqueeze(0)
Esempio n. 4
0
    def create_edt2(self, points):
        """Return ``point_edt2(points, self.grid)`` using this object's grid."""
        return point_edt2(points, self.grid)