def test_linear_curve_edt2_bruteforce_3d(self):
        """Exercise ``curve_edt2_bruteforce`` with ``linear_bezier`` in 3D.

        A two-point segment from (4, 5, 5) to (6, 5, 5) is replicated ``bs``
        times along dim 0 and rasterised on ``make_grid(10, 10, 10)``.  The
        result must have shape ``(bs, 10, 10, 10)`` and the first batch entry
        must be within 0.001 of zero at indices ``[4, 5, 5]`` and ``[5, 5, 5]``.
        """
        side = 10
        grid = make_grid(side, side, side)

        segment = torch.tensor([[[4., 5., 5], [6., 5, 5]]])
        batch = torch.cat([segment for _ in range(bs)], dim=0)

        raster = curve_edt2_bruteforce(
            batch, grid, iters=1, slices=10, cfcn=linear_bezier)

        self.assertEqual(raster.shape, (bs, side, side, side))

        # Only indices 4 and 5 along the first spatial axis are checked.
        for idx in (4, 5):
            value = raster[0, idx, 5, 5].item()
            self.assertAlmostEqual(value, 0.0, delta=0.001)
    def test_curve_edt2_bruteforce_2d(self):
        """Shape check for ``curve_edt2_bruteforce`` across several curve functions.

        The same four control points (replicated ``bs`` times along dim 0) are
        rasterised on ``make_grid(100, 100)`` with ``cubic_bezier``,
        ``quadratic_bezier`` and ``centripetal_catmull_rom_spline`` in turn;
        each result must have shape ``(bs, 100, 100)``.  Values are not checked.
        """
        grid = make_grid(100, 100)

        control_points = torch.tensor(
            [[[50., 30.], [50., 40.], [50., 60.], [50., 70]]])
        batch = torch.cat([control_points for _ in range(bs)], dim=0)

        curve_fns = (
            cubic_bezier,
            quadratic_bezier,
            centripetal_catmull_rom_spline,
        )
        for curve_fn in curve_fns:
            raster = curve_edt2_bruteforce(
                batch, grid, iters=5, slices=20, cfcn=curve_fn)
            self.assertEqual(raster.shape, (bs, 100, 100))
    def test_linear_curve_edt2_bruteforce_2d(self):
        """Exercise ``curve_edt2_bruteforce`` with ``linear_bezier`` in 2D.

        A two-point segment from (50, 40) to (50, 60) is replicated ``bs``
        times along dim 0 and rasterised on ``make_grid(100, 100)``.  The
        result must have shape ``(bs, 100, 100)``.  For the first batch entry,
        every cell ``[50, i]`` with ``40 <= i < 60`` must be approximately 0.0
        and every cell ``[49, i]`` over the same range approximately 1.0.
        """
        grid = make_grid(100, 100)

        segment = torch.tensor([[[50., 40.], [50., 60]]])
        batch = torch.cat([segment for _ in range(bs)], dim=0)

        raster = curve_edt2_bruteforce(
            batch, grid, iters=5, slices=20, cfcn=linear_bezier)

        self.assertEqual(raster.shape, (bs, 100, 100))

        # Row 50 is checked in full before row 49, matching the order in which
        # a failure would first be reported.
        for row, expected in ((50, 0.0), (49, 1.0)):
            for col in range(40, 60):
                self.assertAlmostEqual(raster[0, row, col].item(), expected)
Example #4
0
 def create_edt2(self, lines):
     """Build the ``edt2`` raster for ``lines`` using the configured method.

     Dispatches on ``self.edt_approx``:

     * ``'polyline'``   -> ``curve_edt2_polyline(lines, self.grid, 10, ...)``
     * ``'bruteforce'`` -> ``curve_edt2_bruteforce(lines, self.grid, 2, 10, ...)``

     Both helpers are called with ``cfcn=catmull_rom_spline``.

     Args:
         lines: Curve description forwarded unchanged to the selected helper.

     Returns:
         The value returned by the selected helper.

     Raises:
         NotImplementedError: If ``self.edt_approx`` is neither
             ``'polyline'`` nor ``'bruteforce'``.
     """
     if self.edt_approx == 'polyline':
         # NOTE(review): the meaning of the positional 10 is not visible
         # here -- confirm against curve_edt2_polyline's signature.
         return curve_edt2_polyline(lines,
                                    self.grid,
                                    10,
                                    cfcn=catmull_rom_spline)

     if self.edt_approx == 'bruteforce':
         # NOTE(review): 2 and 10 presumably map to ``iters`` and ``slices``
         # -- confirm against curve_edt2_bruteforce's signature.
         return curve_edt2_bruteforce(lines,
                                      self.grid,
                                      2,
                                      10,
                                      cfcn=catmull_rom_spline)

     # A bare ``raise`` with no active exception only produced an opaque
     # "No active exception to reraise" RuntimeError (or re-raised an unrelated
     # exception when called from inside an ``except`` block).
     # NotImplementedError is a RuntimeError subclass, so callers that caught
     # the old error still do, but now the message names the offending value.
     raise NotImplementedError(
         f"unsupported edt_approx {self.edt_approx!r}; "
         "expected 'polyline' or 'bruteforce'")