def test_line_edt2_2d_ch(self):
        """Each channel of the ``line_edt2`` raster should vanish along its line.

        Two axis-aligned lines (one per channel) are replicated across a batch
        of size ``bs`` (a name defined outside this method). The raster must
        have shape ``(bs, 2, 100, 100)`` and be ~0 at the sampled points of
        each line for every batch element, not only the first one.
        """
        grid = make_grid(100, 100)
        # Channel 0: endpoints (50, 40)-(50, 60); channel 1: (40, 50)-(60, 50).
        line = torch.tensor([[[[50., 40.], [50., 60.]],
                              [[40., 50.], [60., 50.]]]])
        line = torch.cat([line] * bs, dim=0)

        raster = line_edt2(line, grid)
        self.assertEqual(raster.shape, (bs, 2, 100, 100))
        # Check every replicated batch element, not only index 0, so a
        # batching bug in line_edt2 cannot slip through.
        for b in range(bs):
            for i in range(40, 60):
                # .item() hands unittest plain Python floats rather than 0-d
                # tensors, so comparison/rounding and failure messages are
                # ordinary float behaviour.
                self.assertAlmostEqual(raster[b, 0, 50, i].item(), 0.0)
                self.assertAlmostEqual(raster[b, 1, i, 50].item(), 0.0)
Esempio n. 2
0
def render_lines(params, sigma2, grid):
    """Render ``params`` on ``grid`` and prepend a leading dimension of size 1.

    The map returned by ``line_edt2(params, grid)`` (presumably a squared
    distance-to-line field, per the name -- confirm) is passed together with
    ``sigma2`` through the project-level ``exp`` helper. NOTE(review): that
    helper is not ``torch.exp``; it is assumed to apply a ``sigma2``-controlled
    falloff -- confirm against its definition.
    """
    distance_map = line_edt2(params, grid)
    rendered = exp(distance_map, sigma2)
    return rendered.unsqueeze(0)
Esempio n. 3
0
    def create_edt2(self, lines):
        """Return ``line_edt2`` of ``lines`` evaluated on ``self.grid``.

        Thin wrapper that supplies the grid stored on the instance.
        """
        return line_edt2(lines, self.grid)