def forward(self, em: EmitterSet, bg: torch.Tensor = None, ix_low: int = None, ix_high: int = None):
    """
    Build the parameter-list target for a range of frames.

    Emitters are first restricted via ``self._filter_forward`` (frame index shifted to start at 0). Then, for every
    frame, its emitters are written into a fixed-size list of ``self.n_max`` rows. Column 0 holds the photon count
    and columns 1: hold the xyz coordinates in the unit selected by ``self.xy_unit``. Unused rows stay zero and are
    False in the mask.

    Args:
        em: emitters to convert
        bg: returned unchanged
        ix_low: lower frame index; ``self._filter_forward`` resolves None
        ix_high: upper frame index; ``self._filter_forward`` resolves None

    Returns:
        tuple of (post-processed parameter target, post-processed mask target, bg)

    Raises:
        NotImplementedError: if ``self.xy_unit`` is neither 'px' nor 'nm'
        ValueError: if a frame holds more than ``self.n_max`` emitters
    """
    em, ix_low, ix_high = self._filter_forward(em, ix_low, ix_high)
    n_frames = ix_high - ix_low + 1

    # Parameter target: all per-emitter parameters concatenated along the last dim (phot, x, y, z).
    param_tar = torch.zeros((n_frames, self.n_max, 4))
    mask_tar = torch.zeros((n_frames, self.n_max)).bool()

    if self.xy_unit == 'px':
        xyz = em.xyz_px
    elif self.xy_unit == 'nm':
        xyz = em.xyz_nm
    else:
        raise NotImplementedError(f"Unsupported xy_unit: {self.xy_unit}")

    # Hoisted out of the loop; they do not change per frame.
    phot = em.phot
    frame_ix = em.frame_ix

    for i in range(n_frames):
        # A single boolean mask drives both the count and the gathered values, so the two cannot disagree
        # and no intermediate EmitterSet has to be constructed per frame.
        ix = frame_ix == i
        n_emitter = int(ix.sum())

        if n_emitter > self.n_max:
            raise ValueError("Number of actual emitters exceeds number of max. emitters.")

        mask_tar[i, :n_emitter] = True
        param_tar[i, :n_emitter, 0] = phot[ix]
        param_tar[i, :n_emitter, 1:] = xyz[ix]

    return self._postprocess_output(param_tar), self._postprocess_output(mask_tar), bg
def _filter_forward(self, em: EmitterSet, ix_low: (int, None), ix_high: (int, None)):
    """
    Resolve the frame window and restrict the emitters to it.

    Args:
        em: emitters to filter
        ix_low: first frame of the window; ``self.ix_low`` is used when None
        ix_high: last frame of the window; ``self.ix_high`` is used when None

    Returns:
        em (EmitterSet): emitters of the window, frame index shifted by ``-ix_low`` so it starts at 0
        ix_low (int): resolved lower frame index
        ix_high (int): resolved upper frame index
    """
    lower = self.ix_low if ix_low is None else ix_low
    upper = self.ix_high if ix_high is None else ix_high

    # Keep only the frames of interest and shift the frame index so the window starts at 0.
    windowed = em.get_subset_frame(lower, upper, -lower)
    return windowed, lower, upper
def test_forward_different_impl(self, targ):
    """
    Test the implementation against a slow, explicit for loop.

    For every non-zero entry of channel 0, every entry of the surrounding ROI (``targ._roi_size`` wide,
    clamped to the frame extent) must be non-zero in channel 2.

    Args:
        targ: target under test (fixture)
    """
    """Setup"""
    n = 50000
    xyz = torch.rand(n, 3) * 100
    phot = torch.rand_like(xyz[:, 0])
    frame_ix = torch.arange(n)
    em = EmitterSet(xyz, phot, frame_ix, xy_unit='px')

    """Run"""
    out = targ.forward(em, None, 0, n - 1)

    """Assert"""
    # Symmetric ROI offsets. `-(roi - 1) // 2` parses as `(-(roi - 1)) // 2`, which floors towards -inf and
    # makes the window lopsided for even ROI sizes; negating the half-width after the division keeps it symmetric.
    half = (targ._roi_size - 1) // 2
    offsets = range(-half, half + 1)

    # Clamp to the actual frame extent rather than a hard-coded 64 x 64 frame.
    x_max = out.size(-2) - 1
    y_max = out.size(-1) - 1

    non_zero_detect = out[:, [0]].nonzero()
    for det in non_zero_detect:
        ix_n, ix_x0, ix_y0 = det[0], det[-2], det[-1]
        for x in offsets:
            for y in offsets:
                ix_x = torch.clamp(ix_x0 + x, 0, x_max)
                ix_y = torch.clamp(ix_y0 + y, 0, y_max)
                assert out[ix_n, 2, ix_x, ix_y] != 0  # would only fail if either x or y are exactly % 1 == 0
def test_forward(self, targ):
    """Four single-emitter frames near pixel (0, 0): check output shape, one zero entry and four non-zero entries."""
    """Setup"""
    positions = torch.tensor([
        [0., 0., 0.],
        [0.49, 0., 0.],
        [0., 0.49, 0.],
        [0.49, 0.49, 0.],
    ])
    em = EmitterSet(
        xyz=positions,
        phot=torch.ones(4),
        frame_ix=torch.tensor([0, 1, 2, 3]),
        xy_unit='px',
    )

    """Run"""
    tar_out = targ.forward(em, None)

    """Assert"""
    expected_shape = torch.Size([6, 20, 64, 64])
    assert tar_out.size() == expected_shape

    # Negative samples
    assert tar_out[1, 0, 0, 0] == 0.

    # Positive samples: paired advanced indexing -> entries (0, 0), (1, 5), (2, 10), (3, 15) at pixel (0, 0)
    dim0_ix = [0, 1, 2, 3]
    dim1_ix = [0, 5, 10, 15]
    picked = tar_out[dim0_ix, dim1_ix, 0, 0]
    assert (picked == torch.tensor([1., 1., 1., 1.])).all()
def test_forward_statistical(self, targ):
    """
    Sweep one emitter per frame along x and check where the target is (non-)zero.

    The emitter of frame k sits at x = linspace(-10, 78, n)[k], y = 30 (px). Checks:
    - Indices 29 and 31 of the last axis in channel 0 stay zero.
    - Frames whose emitter lies outside [-0.5, 63.5) are entirely zero.
    - Exactly the frames whose emitter lies inside that interval contain non-zero values.

    Args:
        targ: target under test (fixture)
    """
    """Setup"""
    n = 1000
    xyz = torch.zeros((n, 3))
    xyz[:, 0] = torch.linspace(-10, 78., n)
    xyz[:, 1] = 30.
    frame_ix = torch.arange(n)

    em = EmitterSet(xyz, torch.ones_like(xyz[:, 0]), frame_ix, xy_unit='px')

    """Run"""
    out = targ.forward(em, None, 0, n - 1)

    """Assert"""
    assert (out[:, 0, :, 29] == 0).all()
    assert (out[:, 0, :, 31] == 0).all()

    # An emitter is outside the frame if it is left of -0.5 OR at/after 63.5. Multiplying (logical AND) these
    # mutually exclusive conditions selects nothing, which would make the next assertion pass vacuously.
    outside = (xyz[:, 0] < -0.5) | (xyz[:, 0] >= 63.5)
    inside = ~outside
    assert (out[outside] == 0).all()

    # torch.equal fails cleanly (instead of raising or silently broadcasting) if the number of non-empty frames differs.
    assert torch.equal(out.nonzero()[:, 0].unique(), frame_ix[inside])
def fem(self):
    """Single emitter at the origin in frame 0 with one photon (px units)."""
    single_xyz = torch.tensor([[0., 0., 0.]])
    single_phot = torch.Tensor([1.])
    single_frame = torch.tensor([0])
    return EmitterSet(xyz=single_xyz, phot=single_phot, frame_ix=single_frame, xy_unit='px')
def fem(self):
    """Two emitters, one in frame 0 and one in frame 1, with 3 and 2 photons respectively (px units)."""
    coords = torch.tensor([
        [1., 2., 3.],
        [4., 5., 6.],
    ])
    photons = torch.Tensor([3., 2.])
    frames = torch.tensor([0, 1])
    return EmitterSet(xyz=coords, phot=photons, frame_ix=frames, xy_unit='px')