Exemplo n.º 1
0
    def test_init(self, dim):
        """Tests the safety checks of the constructor.

        Expected to raise exceptions when ``match_dims`` is missing or is 1 or 4,
        and to emit a ``UserWarning`` for unlikely threshold combinations
        (all three distances given, or none given).
        """
        with pytest.raises((ValueError, TypeError)):
            match_em.GreedyHungarianMatching()  # match_dims missing

        with pytest.warns(UserWarning):
            match_em.GreedyHungarianMatching(match_dims=dim,
                                             dist_lat=1.,
                                             dist_ax=1.,
                                             dist_vol=1.)  # unlikely comb.

        with pytest.warns(UserWarning):
            match_em.GreedyHungarianMatching(match_dims=dim)  # unlikely comb.

        # match_dims values rejected by the constructor
        for bad_dims in (1, 4):
            with pytest.raises(ValueError):
                match_em.GreedyHungarianMatching(match_dims=bad_dims)
Exemplo n.º 2
0
    def __init__(self, *, raw_th, em_th, xy_unit: str, img_shape, ax_th=None, vol_th=None, lat_th=None,
                 p_aggregation='pbinom_cdf', px_size=None, match_dims=2, diag=0, pphotxyzbg_mapping=None,
                 num_workers=0, skip_th: (None, float) = None, return_format='batch-set', sanity_check=True):
        """
        Store the post-processing settings and build the helpers used later: the emitter
        matching filter, the per-emitter background calculator, the 3x3 neighbour kernel
        and the agglomerative clusterer.

        Args:
            raw_th: stored as ``self.raw_th`` (presumably a threshold on the raw detection
                output; its use is not visible here).
            em_th: stored as ``self.em_th`` (presumably an emitter-level threshold; confirm).
            xy_unit: forwarded to the base class.
            img_shape: image shape forwarded to ``BgPerEmitterFromBgFrame``.
            ax_th: passed as ``dist_ax`` to the matching filter.
            vol_th: passed as ``dist_vol`` to the matching filter. It is also the clustering
                distance threshold when ``match_dims != 2``.
            lat_th: passed as ``dist_lat`` to the matching filter. It is also the clustering
                distance threshold when ``match_dims == 2``.
            p_aggregation: aggregation mode (default ``'pbinom_cdf'``), stored as
                ``self.p_aggregation``.
            px_size: forwarded to the base class.
            match_dims: forwarded to the matching filter. It also selects the clustering
                threshold (``lat_th`` if 2, else ``vol_th``).
            diag: value of the four diagonal entries of the 3x3 neighbour kernel
                (the centre and the edge neighbours are 1).
            pphotxyzbg_mapping: index mapping stored as ``self.pphotxyzbg_mapping``
                (presumably output channels for p, phot, x, y, z, bg; TODO confirm).
                ``None`` (the default) means ``[0, 1, 2, 3, 4, -1]``. A fresh list is
                built per instance, so instances never share one mutable default.
            num_workers: stored as ``self.num_workers``.
            skip_th: relative fraction of the detection output that must be "on" for
                post-processing to be skipped. This is useful during training, when the
                network has not yet converged and large parts of the detection output are
                white (i.e. the detections are not sparse).
            return_format: forwarded to the base class.
            sanity_check: if True, ``self.sanity_check()`` is called at the end of construction.
        """
        super().__init__(xy_unit=xy_unit, px_size=px_size, return_format=return_format)

        # Build the default here rather than in the signature: a list literal in the
        # signature is evaluated once and would be shared by every instance.
        if pphotxyzbg_mapping is None:
            pphotxyzbg_mapping = [0, 1, 2, 3, 4, -1]

        self.raw_th = raw_th
        self.em_th = em_th
        self.p_aggregation = p_aggregation
        self.match_dims = match_dims
        self.num_workers = num_workers
        self.skip_th = skip_th

        self.pphotxyzbg_mapping = pphotxyzbg_mapping

        # Only the bound ``filter`` method of the matcher is kept.
        self._filter = match_emittersets.GreedyHungarianMatching(match_dims=match_dims, dist_lat=lat_th,
                                                                 dist_ax=ax_th, dist_vol=vol_th).filter

        self._bg_calculator = decode.simulation.background.BgPerEmitterFromBgFrame(filter_size=13, xextent=(0., 1.),
                                                                                   yextent=(0., 1.),
                                                                                   img_shape=img_shape)

        # 3x3 kernel reshaped to (1, 1, 3, 3); presumably consumed by a 2D convolution (use not visible here).
        self._neighbor_kernel = torch.tensor([[diag, 1, diag], [1, 1, 1], [diag, 1, diag]]).float().view(1, 1, 3, 3)

        # NOTE(review): scikit-learn >= 1.2 deprecates ``affinity`` in favour of ``metric``
        # (``affinity`` is removed in 1.4); keep this in sync with the pinned sklearn version.
        self._clusterer = AgglomerativeClustering(n_clusters=None,
                                                  distance_threshold=lat_th if self.match_dims == 2 else vol_th,
                                                  affinity='precomputed',
                                                  linkage='single')

        if sanity_check:
            self.sanity_check()
Exemplo n.º 3
0
 def test_filter_kernel_hand(self):
     """Check the lateral-only matching filter on hand-picked targets.

     With ``dist_lat=2`` and no axial or volumetric threshold, a target is
     expected to pair with the (all-zero) outputs exactly when its lateral
     distance to the origin is below 2, whatever its z coordinate.
     """
     # Setup
     matcher = match_em.GreedyHungarianMatching(match_dims=2,
                                                dist_lat=2.,
                                                dist_ax=None,
                                                dist_vol=None)

     # Run: 4 identical outputs at the origin against 4 hand-picked targets.
     # Named ``filter_mask`` to avoid shadowing the builtin ``filter``.
     filter_mask = matcher.filter(
         torch.zeros((4, 3)),
         torch.tensor([[1.9, 0., 0.], [2.1, 0., 0.], [0., 0., -5000.],
                       [1.5, 1.5, 0.]]))

     # Assert: column j holds the pairing flags of target j against every output.
     # All outputs are identical, so each column must be uniformly True or False.
     assert filter_mask[:, 0].all()  # lateral distance 1.9 < 2
     assert not filter_mask[:, 1].any()  # lateral distance 2.1 > 2
     assert filter_mask[:, 2].all()  # at the origin laterally; z is ignored without dist_ax
     assert not filter_mask[:, 3].any()  # lateral distance sqrt(4.5) ~ 2.12 > 2
Exemplo n.º 4
0
    def test_filter_kernel_statistical(self, dist_lat, dist_ax, dist_vol):
        """Check the matching filter on random, batched emitter sets.

        Outputs and targets are drawn uniformly in a [0, 500) x [0, 500) x [0, 1000)
        box for a batch of 10 frames. The resulting active-pair mask is then
        checked by ``self.assert_dists``.
        """
        # Setup
        matcher = match_em.GreedyHungarianMatching(match_dims=2,
                                                   dist_lat=dist_lat,
                                                   dist_ax=dist_ax,
                                                   dist_vol=dist_vol)

        n_batch, n_out, n_tar = 10, 1000, 1200
        # per-axis extent, shaped (1, 1, 3) so it broadcasts over batch and emitters
        extent = torch.tensor([500, 500, 1000]).view(1, 1, 3)
        xyz_out = torch.rand((n_batch, n_out, 3)) * extent
        xyz_tar = torch.rand((n_batch, n_tar, 3)) * extent

        # Run: batched call, yields the active (candidate) pairs
        active_pairs = matcher.filter(xyz_out, xyz_tar)

        # Assert
        self.assert_dists(xyz_out, xyz_tar, dist_lat, dist_ax, dist_vol, active_pairs)
Exemplo n.º 5
0
    def test_match_kernel(self, match_dims, xyz_out, xyz_tar, expected):
        """Run the match kernel on one frame and compare the true-positive indices.

        ``expected`` is a pair that is compared with the nonzero positions of the
        kernel's two boolean index outputs (matched outputs, matched targets).
        """
        matcher = match_em.GreedyHungarianMatching(match_dims=match_dims,
                                                   dist_lat=1.,
                                                   dist_ax=2.,
                                                   dist_vol=None)

        # The filter is fed a batch of one frame; its mask is squeezed back for the kernel.
        batched_mask = matcher.filter(xyz_out.unsqueeze(0), xyz_tar.unsqueeze(0))
        frame_mask = batched_mask.squeeze(0)
        _, _, tp_ix_out, tp_match_ix_out = matcher._match_kernel(xyz_out, xyz_tar, frame_mask)

        exp_tp_ix, exp_tp_match_ix = expected

        # NOTE(review): ``==`` broadcasts, so a shape mismatch (e.g. an empty nonzero
        # result) can make ``.all()`` pass vacuously; consider torch.equal after reshaping.
        assert (tp_ix_out.nonzero() == exp_tp_ix).all()  # boolean index in output
        assert (tp_match_ix_out.nonzero() == exp_tp_match_ix).all()
Exemplo n.º 6
0
 def matcher(self):
     """Provide a 2D GreedyHungarianMatching instance with no distance thresholds passed."""
     matching = match_em.GreedyHungarianMatching(match_dims=2)
     return matching