示例#1
0
    def forward(self, em: EmitterSet, bg: torch.Tensor = None, ix_low: int = None, ix_high: int = None):
        """
        Build a per-frame parameter target and the matching activity mask.

        Args:
            em: emitters to convert; filtered and frame-shifted by ``_filter_forward`` first
            bg: background, handed back to the caller untouched
            ix_low: lower frame index (``_filter_forward`` substitutes a default when None)
            ix_high: upper frame index (``_filter_forward`` substitutes a default when None)

        Returns:
            tuple of (post-processed parameter target, post-processed mask, bg).
            Before post-processing the parameter target has size
            (n_frames, n_max, 4) with ``phot`` in column 0 and ``xyz`` in
            columns 1-3; the mask has size (n_frames, n_max).

        Raises:
            NotImplementedError: if ``self.xy_unit`` is neither 'px' nor 'nm'
            ValueError: if a frame holds more emitters than ``self.n_max``

        """
        em, ix_low, ix_high = self._filter_forward(em, ix_low, ix_high)
        num_frames = ix_high - ix_low + 1

        # Targets in which all parameters are concatenated; unused slots stay zero / False.
        params = torch.zeros((num_frames, self.n_max, 4))
        active = torch.zeros((num_frames, self.n_max)).bool()

        # Pick the coordinates in the configured unit.
        if self.xy_unit not in ('px', 'nm'):
            raise NotImplementedError
        coords = em.xyz_px if self.xy_unit == 'px' else em.xyz_nm

        # Fill the leading slots of every frame with the emitters active on it.
        for frame in range(num_frames):
            count = len(em.get_subset_frame(frame, frame))
            if count > self.n_max:
                raise ValueError("Number of actual emitters exceeds number of max. emitters.")

            active[frame, :count] = True

            on_frame = em.frame_ix == frame
            params[frame, :count, 0] = em[on_frame].phot
            params[frame, :count, 1:] = coords[on_frame]

        return self._postprocess_output(params), self._postprocess_output(active), bg
示例#2
0
    def _filter_forward(self, em: EmitterSet, ix_low: (int, None), ix_high: (int, None)):
        """
        Restrict the emitters to the requested frame window and resolve the frame bounds.

        Args:
            em: emitters to filter
            ix_low: lower frame index; ``self.ix_low`` is used when None
            ix_high: upper frame index; ``self.ix_high`` is used when None

        Returns:
            em (EmitterSet): filtered EmitterSet
            ix_low (int): lower frame index
            ix_high (int): upper frame index

        """
        # Fall back to the instance-level bounds for anything the caller left open.
        ix_low = self.ix_low if ix_low is None else ix_low
        ix_high = self.ix_high if ix_high is None else ix_high

        # Keep only the frames of interest; the -ix_low argument shifts the
        # frame index so that the window starts at 0.
        windowed = em.get_subset_frame(ix_low, ix_high, -ix_low)

        return windowed, ix_low, ix_high
示例#3
0
    def test_forward_different_impl(self, targ):
        """
        Cross-check the target output with a slow, explicit loop over each ROI pixel.

        Args:
            targ: target generator under test; must expose ``forward`` and ``_roi_size``

        Returns:

        """
        # Setup: one randomly placed emitter per frame.
        n = 50000
        xyz = torch.rand(n, 3) * 100
        phot = torch.rand_like(xyz[:, 0])
        frame_ix = torch.arange(n)
        em = EmitterSet(xyz, phot, frame_ix, xy_unit='px')

        # Run
        out = targ.forward(em, None, 0, n - 1)

        # Assert: around every non-zero entry of channel 0, channel 2 must be
        # non-zero over the whole ROI (indices clamped to the 64 px image).
        detections = out[:, [0]].nonzero()
        half_span = (targ._roi_size - 1) // 2
        offsets = range(-(targ._roi_size - 1) // 2, half_span + 1)

        for det in detections:
            ix_n = det[0]
            for dx in offsets:
                ix_x = torch.clamp(det[-2] + dx, 0, 63)
                for dy in offsets:
                    ix_y = torch.clamp(det[-1] + dy, 0, 63)
                    # would only fail if either x or y are exactly % 1 == 0
                    assert out[ix_n, 2, ix_x, ix_y] != 0
示例#4
0
 def test_forward(self, targ):
     """Check output size and a few known entries for four emitters, one per frame 0-3."""
     # Setup
     positions = torch.tensor([[0., 0., 0.], [0.49, 0., 0.],
                               [0., 0.49, 0.], [0.49, 0.49, 0.]])
     em = EmitterSet(xyz=positions,
                     phot=torch.ones(4),
                     frame_ix=torch.tensor([0, 1, 2, 3]),
                     xy_unit='px')

     # Run (frame bounds are left to the target's defaults)
     tar_out = targ.forward(em, None)

     # Assert
     expected_size = torch.Size([6, 20, 64, 64])
     assert tar_out.size() == expected_size

     # Negative samples
     assert tar_out[1, 0, 0, 0] == 0.

     # Positive samples: frame k is expected to be 1 at channel 5 * k, spatial index (0, 0)
     frames = [0, 1, 2, 3]
     channels = [0, 5, 10, 15]
     positives = tar_out[frames, channels, 0, 0]
     assert (positives == torch.tensor([1., 1., 1., 1.])).all()
示例#5
0
    def test_forward_statistical(self, targ):
        """
        Sweep a single emitter per frame along x and check where the target fires.

        The emitters sit at y = 30 and x evenly spaced from -10 to 78, one per
        frame, so part of the sweep lies outside the image. The test expects

        * channel 0 to be zero in the neighbouring columns y = 29 and y = 31,
        * frames whose emitter has x < -0.5 or x >= 63.5 to be entirely zero,
        * exactly the frames with -0.5 <= x < 63.5 to contain non-zero entries.

        Args:
            targ: target generator under test

        """
        # Setup
        n = 1000

        xyz = torch.zeros((n, 3))
        xyz[:, 0] = torch.linspace(-10, 78., n)
        xyz[:, 1] = 30.

        frame_ix = torch.arange(n)

        em = EmitterSet(xyz,
                        torch.ones_like(xyz[:, 0]),
                        frame_ix,
                        xy_unit='px')

        # Run
        out = targ.forward(em, None, 0, n - 1)

        # Assert
        x = xyz[:, 0]
        inside = (x >= -0.5) & (x < 63.5)
        # The two out-of-image conditions are mutually exclusive, so they must
        # be combined with OR; AND-ing them gives an all-False mask and a
        # vacuous assertion.
        outside = (x < -0.5) | (x >= 63.5)

        assert (out[:, 0, :, 29] == 0).all()
        assert (out[:, 0, :, 31] == 0).all()
        assert (out[outside] == 0).all()
        assert (out.nonzero()[:, 0].unique() == frame_ix[inside]).all()
示例#6
0
 def fem(self):
     """Return an EmitterSet holding one emitter at the origin on frame 0, in px units."""
     origin = torch.tensor([[0., 0., 0.]])
     unit_phot = torch.Tensor([1.])
     first_frame = torch.tensor([0])
     return EmitterSet(xyz=origin, phot=unit_phot, frame_ix=first_frame, xy_unit='px')
示例#7
0
 def fem(self):
     """Return an EmitterSet holding two emitters, one on frame 0 and one on frame 1, in px units."""
     positions = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
     photons = torch.Tensor([3., 2.])
     frames = torch.tensor([0, 1])
     return EmitterSet(xyz=positions, phot=photons, frame_ix=frames, xy_unit='px')