Beispiel #1
0
    def write(self, em: EmitterSet, ix_low: int, ix_high: int):
        """
        Write an emitter chunk to file.

        The file is written to ``self._path / (self._name + '_<ix_low>_<ix_last>' + self._suffix)``, where
        ``ix_last`` is ``ix_high`` if ``self._last_index == 'including'`` and ``ix_high - 1`` if it is
        ``'excluding'``.

        Args:
            em: emitter chunk to save (via ``em.save``)
            ix_low: first index of the chunk
            ix_high: upper index of the chunk, inclusive or exclusive depending on ``self._last_index``

        Raises:
            ValueError: if ``self._last_index`` is neither ``'including'`` nor ``'excluding'``.
        """
        if self._last_index == 'including':
            ix_last = ix_high
        elif self._last_index == 'excluding':
            ix_last = ix_high - 1
        else:
            raise ValueError(f"Unsupported last index mode ({self._last_index!r}); "
                             f"expected 'including' or 'excluding'.")

        fname = self._path / f"{self._name}_{ix_low}_{ix_last}{self._suffix}"
        em.save(fname)
 def test_sanity_check(self):
     """EmitterSet construction must raise ValueError for malformed inputs.

     Case 1 (10 photon rows): photon tensor of shape (10, 1) instead of a 1D tensor.
     Case 2 (11 photon rows): photon tensor has 11 rows while xyz and frame_ix have 10 entries.
     """
     # NOTE(review): in case 2 the photon tensor is also 2D, so it may be rejected by the shape
     # check before the element-count check is reached; torch.rand(11) would isolate the count check.
     for n_phot_rows in (10, 11):
         coords = torch.rand((10, 3))
         photons = torch.rand((n_phot_rows, 1))
         frames = torch.rand(10)
         with pytest.raises(ValueError):
             EmitterSet(coords, photons, frames)
    def test_split_in_frames(self, em2d, em3d):
        """split_in_frames without bounds yields one split per frame; negative frames and a lower bound work."""
        # em2d without explicit bounds gives exactly one split.
        assert len(em2d.split_in_frames(None, None)) == 1

        # Without bounds, there is one split per frame between the smallest and largest frame index.
        per_frame = em3d.split_in_frames(None, None)
        n_frames = em3d.frame_ix.max() - em3d.frame_ix.min() + 1
        assert n_frames == len(per_frame)

        # Negative frame indices.
        neg_frames = EmitterSet(torch.rand((3, 3)), torch.rand(3),
                                torch.tensor([-1, 0, 1]))
        assert len(neg_frames.split_in_frames(None, None)) == 3
        # With a lower bound of 0, frame -1 is no longer part of the range.
        assert len(neg_frames.split_in_frames(0, None)) == 2
Beispiel #4
0
    def forward(self, output: emitter.EmitterSet, target: emitter.EmitterSet):
        """
        Match ``output`` emitters to ``target`` emitters frame by frame.

        The frame range is determined automatically so as to cover every frame present in either set.

        Args:
            output: emitters to be matched
            target: reference emitters; their IDs are copied to the matched true positives

        Returns:
            Result of ``self._return_match`` for the true positives (tp), false positives (fp),
            false negatives (fn) and the target emitters matched to the true positives (tp_match).
        """
        # Only non-empty sets contribute to the frame range.
        non_empty = [em for em in (output, target) if len(em) >= 1]
        if not non_empty:
            # Nothing to match. Build four independent empty sets (not four references to one object)
            # with the target's xy unit / px size, returned through _return_match like the regular path.
            empties = [
                emitter.EmptyEmitterSet(xy_unit=target.xy_unit, px_size=target.px_size)
                for _ in range(4)
            ]
            return self._return_match(tp=empties[0], fp=empties[1], fn=empties[2], tp_match=empties[3])

        frame_low = min(em.frame_ix.min().item() for em in non_empty)
        frame_high = max(em.frame_ix.max().item() for em in non_empty)

        out_pframe = output.split_in_frames(frame_low, frame_high)
        tar_pframe = target.split_in_frames(frame_low, frame_high)

        # true positives, false positives, false negatives, targets matched to the true positives
        tpl, fpl, fnl, tpml = [], [], [], []

        # Match the emitters framewise
        for out_f, tar_f in zip(out_pframe, tar_pframe):
            filter_mask = self.filter(out_f.xyz_nm, tar_f.xyz_nm)  # batch implemented
            tp_ix, tp_match_ix, tp_ix_bool, tp_match_ix_bool = self._match_kernel(
                out_f.xyz_nm, tar_f.xyz_nm, filter_mask)  # non batch impl.

            tpl.append(out_f[tp_ix])
            tpml.append(tar_f[tp_match_ix])
            fpl.append(out_f[~tp_ix_bool])
            fnl.append(tar_f[~tp_match_ix_bool])

        # Concat them back
        tp = emitter.EmitterSet.cat(tpl)
        fp = emitter.EmitterSet.cat(fpl)
        fn = emitter.EmitterSet.cat(fnl)
        tp_match = emitter.EmitterSet.cat(tpml)

        # Let tp and tp_match share the same IDs: IDs of the target are copied to the true positives.
        # If every matched target ID is -1, enumerate the matches first.
        if (tp_match.id == -1).all().item():
            tp_match.id = torch.arange(len(tp_match)).type(tp_match.id.dtype)

        tp.id = tp_match.id.type(tp.id.dtype)

        return self._return_match(tp=tp, fp=fp, fn=fn, tp_match=tp_match)
    def test_to_dict(self):
        """One round trip through to_dict and the EmitterSet constructor must preserve the set."""
        original = RandomEmitterSet(100, xy_unit='nm', px_size=(100., 200.))
        reference = original.clone()

        rebuilt = EmitterSet(**original.to_dict())
        assert reference == rebuilt
    def test_save_load(self, format, tmpdir):
        """
        Saving an EmitterSet and loading it back must yield an equivalent set.

        Args:
            format: suffix appended to the file name 'em' (presumably supplied by parametrization;
                it selects the storage format through the file extension — TODO confirm)
            tmpdir: pytest temporary directory fixture
        """
        em = RandomEmitterSet(1000, xy_unit='nm', px_size=(100., 100.))

        p = Path(tmpdir / f'em{format}')
        em.save(p)
        em_load = EmitterSet.load(p)
        assert em == em_load, "Reloaded emitterset is not equivalent to initial one."
def em3d():
    """Most basic (i.e. all necessary fields) 3D EmitterSet.

    25 emitters with random xyz and photon values. Frame indices equal the emitter index, except that
    the first three emitters are all in frame 1 (so frames 0 and 2 are empty).
    """
    n_emitters = 25
    frame_ix = torch.arange(n_emitters, dtype=torch.long)
    frame_ix[:3] = 1

    xyz = torch.rand((n_emitters, 3))
    phot = torch.rand(n_emitters)
    return EmitterSet(xyz=xyz, phot=phot, frame_ix=frame_ix)
Beispiel #8
0
    def forward(self, features: torch.Tensor):
        """
        Forward the feature map through the post processing and return an EmitterSet.

        After reordering by ``self.pphotxyzbg_mapping`` the input features follow this convention:

            0 - Detection channel

            1 - Photon channel

            2 - 'x' channel

            3 - 'y' channel

            4 - 'z' channel

            5 - Background channel

        The coordinates are stored in the returned EmitterSet with unit ``self.xy_unit``.

        Args:
            features (torch.Tensor): Features of size :math:`(N, C, H, W)`

        Returns:
            EmitterSet: detected emitters with ``xy_unit=self.xy_unit`` and ``px_size=self.px_size``;
            an EmptyEmitterSet if ``self.skip_if(features)`` is true.

        Raises:
            ValueError: if ``features`` is not 4-dimensional.
        """
        if self.skip_if(features):
            return EmptyEmitterSet(xy_unit=self.xy_unit, px_size=self.px_size)

        if features.dim() != 4:
            raise ValueError(
                "Wrong dimensionality. Needs to be N x C x H x W.")

        features = features[:, self.pphotxyzbg_mapping]  # change channel order if needed

        p = features[:, [0], :, :]  # detection channel, kept 4D as (N, 1, H, W)
        features = features[:, 1:, :, :]  # phot, x, y, z, bg

        p_out, feat_out = self._forward_raw_impl(p, features)

        feature_list, prob_final, frame_ix = self._frame2emitter(p_out, feat_out)
        frame_ix = frame_ix.squeeze()
        if frame_ix.dim() == 0:
            # squeeze() collapses a single detection to a 0-d tensor; restore the emitter dimension
            # so that frame_ix stays a 1D tensor with one entry per emitter.
            frame_ix = frame_ix.unsqueeze(0)

        return EmitterSet(xyz=feature_list[:, 1:4],
                          phot=feature_list[:, 0],
                          frame_ix=frame_ix,
                          prob=prob_final,
                          bg=feature_list[:, 4],
                          xy_unit=self.xy_unit,
                          px_size=self.px_size)
    def test_cat_emittersets(self):
        """EmitterSet.cat: resulting length, frame index handling, and xy unit / px size of the result."""
        # Frame step of 1: per the assertions, the first set ends up in frame 0 and the second in frame 1.
        merged = EmitterSet.cat([RandomEmitterSet(50), RandomEmitterSet(20)], None, 1)
        assert len(merged) == 70
        assert merged.frame_ix[0] == 0
        assert merged.frame_ix[50] == 1

        # Explicit frame indices: the two sets are placed in frames 5 and 50.
        merged = EmitterSet.cat([RandomEmitterSet(50), RandomEmitterSet(20)],
                                torch.tensor([5, 50]), None)
        assert len(merged) == 70
        assert merged.frame_ix[0] == 5
        assert merged.frame_ix[50] == 50

        # The result carries the 'px' unit and px size of the first set.
        merged = EmitterSet.cat([
            RandomEmitterSet(50, xy_unit='px', px_size=(100., 200.)),
            RandomEmitterSet(20),
        ])
        assert merged.xy_unit == 'px'
        assert (merged.px_size == torch.tensor([100., 200.])).all()
def em3d_full(em3d):
    """Extend the basic 3D EmitterSet with random values for all optional fields.

    xyz, phot, frame_ix and id are taken from ``em3d``. bg, the *_sig fields and the *_cr fields are
    random. The set is stored in nm with a px size of (100., 200.).
    """
    xyz, phot = em3d.xyz, em3d.phot

    # Random values are drawn in the same order as the fields are listed below.
    bg = torch.rand_like(phot) * 100
    xyz_sig = torch.rand_like(xyz)
    phot_sig = torch.rand_like(phot) * phot.sqrt()
    xyz_cr = torch.rand_like(xyz) ** 2
    phot_cr = torch.rand_like(phot) * phot.sqrt() * 1.5
    bg_cr = torch.rand_like(phot)

    return EmitterSet(xyz=xyz, phot=phot, bg=bg, frame_ix=em3d.frame_ix, id=em3d.id,
                      xyz_sig=xyz_sig, phot_sig=phot_sig,
                      xyz_cr=xyz_cr, phot_cr=phot_cr, bg_cr=bg_cr,
                      xy_unit='nm', px_size=(100., 200.))
    def test_chunk(self):
        """chunks() must partition a set such that concatenating the pieces restores the original."""
        large = RandomEmitterSet(100000)
        pieces = large.chunks(10000)
        merged = EmitterSet.cat(pieces)

        assert sum(len(piece) for piece in pieces) == len(large)
        assert merged == large

        # 7 emitters are not evenly divisible by 3; expected piece sizes are 3, 2, 2.
        small = RandomEmitterSet(7)
        pieces = small.chunks(3)

        assert len(pieces[0]) == 3
        assert len(pieces[1]) == 2
        assert len(pieces[-1]) == 2
Beispiel #12
0
def transform_emitter(em: emitter.EmitterSet,
                      trafo: dict) -> emitter.EmitterSet:
    """
    Transform a set of emitters specified by a transformation dictionary. Returns transformed emitterset.

    The input ``em`` is not modified; all changes are applied to a clone.

    Args:
        em: emitterset to be transformed
        trafo: transformation specs. Recognised keys, applied in this order: 'px_size', 'xyz_axis',
            'xyz_nm_factor', 'xyz_nm_shift', 'xyz_px_factor', 'xyz_px_shift', 'frame_ix_shift',
            'xy_unit'. A key that is missing or set to None skips that step.

    Returns:
        transformed clone of ``em``

    Raises:
        ValueError: if 'xy_unit' is given and is neither 'nm' nor 'px'.
    """
    mod_em = em.clone()

    # Set px size
    px_size = trafo.get('px_size')
    if px_size is not None:
        mod_em.px_size = torch.tensor(px_size)

    # Modify proper attributes
    xyz_axis = trafo.get('xyz_axis')
    if xyz_axis is not None:
        # reorder / select coordinate axes consistently for values, CR and sigma
        mod_em.xyz = mod_em.xyz[:, xyz_axis]
        mod_em.xyz_cr = mod_em.xyz_cr[:, xyz_axis]
        mod_em.xyz_sig = mod_em.xyz_sig[:, xyz_axis]

    if trafo.get('xyz_nm_factor') is not None:
        mod_em.xyz_nm *= torch.tensor(trafo['xyz_nm_factor'])

    if trafo.get('xyz_nm_shift') is not None:
        mod_em.xyz_nm += torch.tensor(trafo['xyz_nm_shift'])

    if trafo.get('xyz_px_factor') is not None:
        mod_em.xyz_px *= torch.tensor(trafo['xyz_px_factor'])

    if trafo.get('xyz_px_shift') is not None:
        mod_em.xyz_px += torch.tensor(trafo['xyz_px_shift'])

    if trafo.get('frame_ix_shift') is not None:
        mod_em.frame_ix += torch.tensor(trafo['frame_ix_shift'])

    # Modify unit in which the emitters are stored. Re-assigning through the unit-specific
    # property is relied upon to convert and switch the stored unit (and uses the px size set above).
    xy_unit = trafo.get('xy_unit')
    if xy_unit is not None:
        if xy_unit == 'nm':
            mod_em.xyz_nm = mod_em.xyz_nm
        elif xy_unit == 'px':
            mod_em.xyz_px = mod_em.xyz_px
        else:
            raise ValueError(f"Unsupported unit ({trafo['xy_unit']}).")

    return mod_em
    def test_split_cat(self):
        """
        Splitting into frames and concatenating again must reproduce the original set once both are sorted by ID.
        """
        # Setup: unique IDs and random frame indices in [0, 10000)
        em = RandomEmitterSet(1000)
        em.id = torch.arange(len(em))
        em.frame_ix = torch.randint_like(em.frame_ix, 10000)

        # Run
        em_re_merged = EmitterSet.cat(em.split_in_frames(0, 9999))

        # Compare after sorting both sets by ID, since the emitter order may change through the split.
        assert em[torch.argsort(em.id)] == em_re_merged[torch.argsort(em_re_merged.id)]
Beispiel #14
0
    def forward(self, x: torch.Tensor) -> EmitterSet:
        """
        Run the model output through post-processing and build an EmitterSet.

        Sigma values for photons and xyz are included when ``self.photxyz_sigma_mapping`` is set.

        Args:
            x: model output

        Returns:
            EmitterSet (moved to CPU) with unit ``self.xy_unit`` and px size ``self.px_size``
        """
        # Reorder channels via pphotxyzbg_mapping; channel 0 is then the detection channel.
        x_mapped = x[:, self.pphotxyzbg_mapping]
        detection = x_mapped[:, 0]

        # Select active pixels and read their detection probability.
        active_px = self._filter(detection)
        prob = detection[active_px]

        # Look up the remaining channels (phot, x, y, z and optionally bg) at the active pixels.
        frame_ix, features = self._lookup_features(x_mapped[:, 1:], active_px)
        xyz = features[1:4].transpose(0, 1)

        # Optional sigma channels, looked up at the same pixels.
        xyz_sigma, phot_sigma = None, None
        if self.photxyz_sigma_mapping is not None:
            _, features_sigma = self._lookup_features(x[:, self.photxyz_sigma_mapping], active_px)
            xyz_sigma = features_sigma[1:4].transpose(0, 1).cpu()
            phot_sigma = features_sigma[0].cpu()

        # A background channel is only present when five feature rows were looked up.
        bg = None
        if features.size(0) == 5:
            bg = features[4, :].cpu()

        return EmitterSet(xyz=xyz.cpu(),
                          frame_ix=frame_ix.cpu(),
                          phot=features[0, :].cpu(),
                          xyz_sig=xyz_sigma,
                          phot_sig=phot_sigma,
                          bg_sig=None,
                          bg=bg,
                          prob=prob.cpu(),
                          xy_unit=self.xy_unit,
                          px_size=self.px_size)
    def test_adjacent_frame_split(self):
        """split_in_frames over the adjacent frames -1, 0 and 1, for the full range and single frames."""
        xyz = torch.rand((500, 3))
        phot = torch.rand_like(xyz[:, 0])
        frame_ix = torch.randint_like(xyz[:, 0], low=-1, high=2).int()
        em = EmitterSet(xyz, phot, frame_ix)

        # Full range: split i holds exactly the emitters of the i-th frame in ascending order.
        full_split = em.split_in_frames(-1, 1)
        for i, frame in enumerate((-1, 0, 1)):
            assert (full_split[i].frame_ix == frame).all()

        # Single-frame ranges yield exactly one split containing only that frame.
        for frame in (0, -1, 1):
            single = em.split_in_frames(frame, frame)
            assert len(single) == 1
            assert (single[0].frame_ix == frame).all()
def em2d():
    """Effectively 2D EmitterSet: 25 emitters with only two coordinate columns, all in frame 0."""
    n_emitters = 25
    coords = torch.rand((n_emitters, 2))
    photons = torch.rand(n_emitters)
    frames = torch.zeros(n_emitters, dtype=torch.long)
    return EmitterSet(xyz=coords, phot=photons, frame_ix=frames)