示例#1
0
    def forward(self, x: torch.Tensor):
        """
        Produce an empty EmitterSet, independent of the input.

        The input tensor is not inspected at all; only this instance's
        ``xy_unit`` and ``px_size`` attributes are passed on to the result.

        Args:
            x (torch.Tensor): any input tensor where the first dim is the batch-dim
                (ignored by this implementation).

        Returns:
            EmptyEmitterSet: An empty EmitterSet tagged with ``self.xy_unit``
            and ``self.px_size``.

        """
        unit = self.xy_unit
        size = self.px_size
        return EmptyEmitterSet(xy_unit=unit, px_size=size)
示例#2
0
    def forward(self, features: torch.Tensor):
        """
        Forward the feature map through the post-processing and return an EmitterSet.

        After the channel re-ordering given by ``self.pphotxyzbg_mapping`` the
        features are expected to follow this convention:

            0 - Detection channel

            1 - Photon channel

            2 - 'x' channel

            3 - 'y' channel

            4 - 'z' channel

            5 - Background channel

        The resulting EmitterSet is labelled with ``self.xy_unit`` and
        ``self.px_size``, so the x and y channels must already be expressed in
        that unit (documented upstream as nano-metres -- confirm against the
        configured ``xy_unit``).

        Args:
            features (torch.Tensor): Features of size :math:`(N, C, H, W)`

        Returns:
            EmitterSet: the detected emitters, or an EmptyEmitterSet when
            ``self.skip_if(features)`` is truthy.

        Raises:
            ValueError: if ``features`` (not skipped) is not 4-dimensional.

        """

        if self.skip_if(features):
            return EmptyEmitterSet(xy_unit=self.xy_unit, px_size=self.px_size)

        if features.dim() != 4:
            raise ValueError(
                "Wrong dimensionality. Needs to be N x C x H x W.")

        # change channel order if needed
        features = features[:, self.pphotxyzbg_mapping]

        # Indexing with a list ([0]) keeps the channel axis (size 1).
        p = features[:, [0], :, :]
        features = features[:, 1:, :, :]  # phot, x, y, z, bg

        p_out, feat_out = self._forward_raw_impl(p, features)

        feature_list, prob_final, frame_ix = self._frame2emitter(p_out, feat_out)
        # Flatten to 1-D. A bare squeeze() removes *every* size-1 dim, so a
        # single detection would collapse to a 0-dim tensor and lose the
        # emitter axis that xyz/phot/prob/bg keep.
        frame_ix = frame_ix.reshape(-1)

        return EmitterSet(xyz=feature_list[:, 1:4],
                          phot=feature_list[:, 0],
                          frame_ix=frame_ix,
                          prob=prob_final,
                          bg=feature_list[:, 4],
                          xy_unit=self.xy_unit,
                          px_size=self.px_size)
示例#3
0
def test_empty_emitterset():
    """A default-constructed EmptyEmitterSet contains no emitters."""
    empty_set = EmptyEmitterSet()
    assert len(empty_set) == 0