Ejemplo n.º 1
0
    def test_result_dictablet(self, evaluator):
        """Check that an evaluation result can be turned into a plain dict.

        Four random emitter sets (sizes 10, 1, 2, 10; all in nm) are fed to
        ``forward``. The returned object must provide ``_asdict()`` yielding a
        ``dict`` whose ``'prec'`` entry matches the ``prec`` attribute.
        """
        inputs = [em.RandomEmitterSet(size, xy_unit='nm')
                  for size in (10, 1, 2, 10)]
        result = evaluator.forward(*inputs)

        assert isinstance(result._asdict(), dict)
        assert result.prec == result._asdict()['prec']
Ejemplo n.º 2
0
class TestSegmentationEval(TestEval):
    """Tests for ``evaluation.SegmentationEvaluation`` on handcrafted inputs."""

    @pytest.fixture()
    def evaluator(self):
        """Provide a fresh ``SegmentationEvaluation`` instance."""
        return evaluation.SegmentationEvaluation()

    test_data = [
        # tp, fp, fn all empty -> all four outcomes expected to be NaN
        (em.EmptyEmitterSet(), em.EmptyEmitterSet(), em.EmptyEmitterSet(),
         (float('nan'), ) * 4),
        # a single false positive only
        (em.EmptyEmitterSet(), em.RandomEmitterSet(1), em.EmptyEmitterSet(),
         (0., float('nan'), 0., float('nan'))),
    ]

    @pytest.mark.parametrize("tp,fp,fn,expect", test_data)
    def test_segmentation(self, evaluator, tp, fp, fn, expect):
        """
        Compare the evaluator output against handcrafted expected values.

        Args:
            tp: true positives
            fp: false positives
            fn: false negatives
            expect(tuple): expected outcome, one value per returned metric

        """
        out = evaluator.forward(tp, fp, fn)

        for got, want in zip(out, expect):
            got_nan = math.isnan(got)
            want_nan = math.isnan(want)
            if not (got_nan or want_nan):
                assert got == want
                continue

            # NaN never compares equal, so both sides must be NaN
            assert got_nan
            assert want_nan
Ejemplo n.º 3
0
 def test_sigma_filter(self, frac):
     """``filter_by_sigma(fraction=frac)`` keeps ``frac`` of the emitters.

     Lengths are compared with ``pytest.approx`` at its default tolerance.
     """
     em = emitter.RandomEmitterSet(10000)
     # standard-normal sigmas shifted by 5 and clamped to be non-negative
     em.xyz_sig = (torch.randn_like(em.xyz_sig) + 5).clamp(0.)

     out = em.filter_by_sigma(fraction=frac)

     expected_len = len(em) * frac
     assert expected_len == pytest.approx(len(out))
Ejemplo n.º 4
0
class TestBgPerEmitterFromBgFrame:
    """Tests for ``background.BgPerEmitterFromBgFrame``."""

    @pytest.fixture(scope='class')
    def extractor(self):
        """Class-wide extractor instance.

        NOTE(review): the positional arguments look like a filter size (17),
        x/y extents (-0.5, 63.5) and a (64, 64) image shape -- confirm against
        the ``BgPerEmitterFromBgFrame`` signature.
        """
        return background.BgPerEmitterFromBgFrame(17, (-0.5, 63.5),
                                                  (-0.5, 63.5), (64, 64))

    def test_mean_filter(self, extractor):
        """
        Check ``_mean_filter`` on three hand-crafted (1, 1, 64, 64) inputs.

        Cases:
            0. standard-normal noise, expected close to zero (tolerance 1)
            1. all zeros, expected exactly zero
            2. a ramp whose value equals the index along dim 2; row 8 of the
               filtered output is expected to be 8

        Args:
            extractor: fixture as above

        """
        x_in = [
            torch.randn((1, 1, 64, 64)),
            torch.zeros((1, 1, 64, 64)),
            torch.meshgrid(
                torch.arange(64),
                torch.arange(64))[0].unsqueeze(0).unsqueeze(0).float(),
        ]

        # expected outcome per case
        expect = [
            torch.zeros_like(x_in[0]),
            torch.zeros_like(x_in[0]),
            8,  # value of row 8 of the filtered ramp
        ]

        out = [extractor._mean_filter(x) for x in x_in]

        # noise only averages towards zero, hence the generous tolerance
        assert test_utils.tens_almeq(out[0], expect[0], 1)
        assert test_utils.tens_almeq(out[1], expect[1])
        assert test_utils.tens_almeq(
            out[2][0, 0, 8, :],
            expect[2] * torch.ones_like(out[2][0, 0, 0, :]),
            1e-4)

    test_data = [
        # zero background -> zero bg for each of the 100 emitters
        (torch.zeros((1, 1, 64, 64)), emitter.RandomEmitterSet(100),
         torch.zeros((100, ))),
        # ramp background (value == index along dim 2), emitter at x=8 -> 8
        (torch.meshgrid(torch.arange(64),
                        torch.arange(64))[0].unsqueeze(0).unsqueeze(0).float(),
         emitter.CoordinateOnlyEmitter(torch.tensor([[8., 0., 0.]])),
         torch.tensor([8.])),
        # emitter at x=70, beyond the extractor's (-0.5, 63.5) extent -> NaN
        (torch.rand((1, 1, 64, 64)),
         emitter.CoordinateOnlyEmitter(torch.tensor([[70., 32., 0.]])),
         torch.tensor([float('nan')])),
    ]

    @pytest.mark.parametrize("bg,em,expect_bg", test_data)
    def test_forward(self, extractor, bg, em, expect_bg):
        """``forward`` writes the expected per-emitter ``bg`` (NaN-aware compare)."""
        out = extractor.forward(em, bg)

        assert test_utils.tens_almeq(out.bg, expect_bg, 1e-4, nan=True)
Ejemplo n.º 5
0
    def test_distance_excpt(self, evaluator):
        """``forward`` must raise ``ValueError`` for an empty vs. one-emitter input.

        Args:
            evaluator: evaluator fixture under test

        """
        with pytest.raises(ValueError):
            # built inside the context, exactly like the call itself
            empty_set = em.EmptyEmitterSet('nm')
            single_set = em.RandomEmitterSet(1, xy_unit='nm')
            evaluator.forward(empty_set, single_set)
Ejemplo n.º 6
0
 def test_forward(self, sim, ix_low, ix_high, n):
     """Tests the output length of forward method of simulation.

     Two random emitters on frames -2 and 3 are simulated with the given
     ``ix_low``/``ix_high``; the number of returned frames must equal ``n``.
     """
     sim.frame_range = (None, None)

     emitters = emitter.RandomEmitterSet(2)
     emitters.frame_ix = torch.tensor([-2, 3]).long()

     frames, _bg_frames = sim.forward(emitters, ix_low=ix_low, ix_high=ix_high)

     assert len(frames) == n, "Wrong number of frames."
Ejemplo n.º 7
0
def test_streamer(last_index, tmpdir):
    """``EmitterWriteStream`` builds its filename according to ``last_index``.

    ``'including'`` keeps the upper index (``dummy_0_100.pt``), ``'excluding'``
    lowers it by one (``dummy_0_99.pt``). Calling the stream object directly
    must also result in exactly one save.
    """
    stream = emitter_io.EmitterWriteStream('dummy', '.pt', tmpdir,
                                           last_index=last_index)

    with mock.patch.object(emitter.EmitterSet, 'save') as save_via_write:
        stream.write(emitter.RandomEmitterSet(20), 0, 100)

    expected_file = {
        'including': 'dummy_0_100.pt',
        'excluding': 'dummy_0_99.pt',
    }.get(last_index)
    if expected_file is not None:
        save_via_write.assert_called_once_with(tmpdir / expected_file)

    with mock.patch.object(emitter.EmitterSet, 'save') as save_via_call:
        stream(emitter.RandomEmitterSet(20), 0, 100)

    save_via_call.assert_called_once()
Ejemplo n.º 8
0
    def test_hist_detection(self):
        """``hist_detection`` returns a mapping keyed by prob and per-axis sigma."""
        em = emitter.RandomEmitterSet(10000)
        em.prob = torch.rand_like(em.prob)
        # per-axis scale 1, 2, 3 broadcast over all emitters
        noise = torch.randn_like(em.xyz_sig)
        em.xyz_sig = noise * torch.tensor([1., 2., 3.]).unsqueeze(0)

        out = em.hist_detection()

        expected_keys = {'prob', 'sigma_x', 'sigma_y', 'sigma_z'}
        assert set(out.keys()) == expected_keys
Ejemplo n.º 9
0
class TestWeightedErrors(TestEval):
    """Tests for ``evaluation.WeightedErrors`` in 'phot' and 'crlb' mode."""

    @pytest.fixture(params=['phot', 'crlb'])
    def evaluator(self, request):
        """WeightedErrors with 'mstd' reduction, parametrised over the mode."""
        return evaluation.WeightedErrors(mode=request.param, reduction='mstd')

    # 'phot' is a valid mode; pairing it with an invalid reduction must still
    # raise, which ensures the reduction argument is validated as well.
    @pytest.mark.parametrize("mode", [None, 'abc', 'phot'])
    @pytest.mark.parametrize("reduction", ['None', 'abc'])
    def test_sanity(self, evaluator, mode, reduction):
        """Invalid mode/reduction combinations raise ``ValueError`` on init."""
        with pytest.raises(ValueError):
            evaluator.__init__(mode=mode, reduction=reduction)

    def test_forward_handcrafted(self, evaluator):
        """Weighted errors follow the ordering built into handcrafted data.

        Four emitters with increasing photon counts are compared against a
        reference shifted by 0.5 in xyz with per-emitter CRLBs. The absolute
        weighted position errors must increase with the emitter index; the
        photon and background errors must decrease.
        """
        tp = em.EmitterSet(xyz=torch.zeros((4, 3)),
                           phot=torch.tensor([1050., 1950., 3050., 4050]),
                           frame_ix=torch.tensor([0, 0, 1, 2]),
                           bg=torch.ones((4, )) * 10,
                           xy_unit='px',
                           px_size=(127., 117.))

        ref = tp.clone()
        ref.xyz += 0.5
        ref.phot = torch.tensor([1000., 2000., 3000., 4000.])
        ref.xyz_cr = (torch.tensor([[10., 10., 15], [8., 8., 10], [6., 6., 7],
                                    [4., 4., 5.]]) / 100.)**2
        ref.phot_cr = torch.tensor([10., 12., 14., 16.])**2
        ref.bg_cr = torch.tensor([1., 2., 3., 4])**2

        # only the non-reduced values are checked
        _, _, _, dpos, dphot, dbg = evaluator.forward(tp, ref)

        assert (dpos.abs().argsort(0) ==
                torch.arange(4).unsqueeze(1).repeat(1, 3)).all(), \
            "Weighted error for pos. should be monot. increasing"

        assert (dphot.abs().argsort(descending=True) ==
                torch.arange(4)).all(), \
            "Weighted error for photon should be monot. decreasing"

        assert (dbg.abs().argsort(descending=True) ==
                torch.arange(4)).all(), \
            "Weighted error for background should be monot. decreasing"

    data_forward_sanity = [
        # both empty: no error, empty non-reduced outputs
        (em.EmptyEmitterSet(xy_unit='nm'), em.EmptyEmitterSet(xy_unit='nm'),
         False, (torch.empty((0, 3)), torch.empty((0, )), torch.empty((0, )))),
        # unequal size: ValueError expected
        (em.RandomEmitterSet(5), em.EmptyEmitterSet(), True, None)
    ]

    @pytest.mark.parametrize("tp,ref,expt_err,expt_out", data_forward_sanity)
    def test_forward_sanity(self, evaluator, tp, ref, expt_err, expt_out):
        """
        General forward sanity checks.
            1. Both empty sets of emitters
            2. Unequal size

        """
        if expt_err and expt_out is not None:
            raise RuntimeError("Wrong test setup.")

        if expt_err:
            with pytest.raises(ValueError):
                _ = evaluator.forward(tp, ref)
            return

        out = evaluator.forward(tp, ref)

        assert isinstance(out, evaluator._return), "Wrong output type"
        # compare only the non-reduced outputs (index 3 onwards)
        for out_i, expt_i in zip(out[3:], expt_out):
            assert test_utils.tens_almeq(out_i, expt_i, 1e-4)

    def test_reduction(self, evaluator):
        """
        Check ``_reduce`` in 'mstd' and 'gaussian' mode on large random samples.

        Args:
            evaluator: evaluator fixture (its ``_reduce`` method is exercised)

        """
        # mean and std
        dxyz, dphot, dbg = torch.randn(
            (250000, 3)), torch.randn(250000) + 20, torch.rand(250000)
        dxyz_, dphot_, dbg_ = evaluator._reduce(dxyz, dphot, dbg, 'mstd')

        assert test_utils.tens_almeq(dxyz_[0], torch.zeros((3, )), 1e-2)
        assert test_utils.tens_almeq(dxyz_[1], torch.ones((3, )), 1e-2)

        assert test_utils.tens_almeq(dphot_[0], torch.zeros((1, )) + 20, 1e-2)
        assert test_utils.tens_almeq(dphot_[1], torch.ones((1, )), 1e-2)

        # dbg ~ U[0, 1): mean 0.5, std 1/sqrt(12) ~= 0.2887
        assert test_utils.tens_almeq(dbg_[0], torch.zeros((1, )) + 0.5, 1e-2)
        assert test_utils.tens_almeq(dbg_[1], torch.ones((1, )) * 0.2889, 1e-2)

        # gaussian fit
        dxyz, dphot, dbg = torch.randn(
            (250000, 3)), torch.randn(250000) + 20, torch.randn(250000)
        dxyz_, dphot_, dbg_ = evaluator._reduce(dxyz, dphot, dbg, 'gaussian')

        assert test_utils.tens_almeq(dxyz_[0], torch.zeros((3, )), 1e-2)
        assert test_utils.tens_almeq(dxyz_[1], torch.ones((3, )), 1e-2)

        assert test_utils.tens_almeq(dphot_[0], torch.zeros((1, )) + 20, 1e-2)
        assert test_utils.tens_almeq(dphot_[1], torch.ones((1, )), 1e-2)

        assert test_utils.tens_almeq(dbg_[0], torch.zeros((1, )), 1e-2)
        assert test_utils.tens_almeq(dbg_[1], torch.ones((1, )), 1e-2)

    plot_test_data = [
        (torch.empty((0, 3)), torch.empty((0, 3)), torch.empty((0, 3))),
        (torch.randn((25000, 3)), torch.randn(25000), torch.randn(25000)),
    ]

    # NOTE: plt.subplots(5) runs when the class body is executed (collection
    # time), so a figure is created even if the plot tests are deselected.
    plot_test_axes = [None, plt.subplots(5)[1]]

    @pytest.mark.plot
    @pytest.mark.parametrize("dxyz,dphot,dbg", plot_test_data)
    @pytest.mark.parametrize("axes", plot_test_axes)
    def test_plot_hist(self, evaluator, dxyz, dphot, dbg, axes):
        """Smoke test: ``plot_error`` runs with and without given axes."""
        evaluator.plot_error(dxyz, dphot, dbg, axes=axes)
        plt.show()
Ejemplo n.º 10
0
 def sample(self):
     """Return a random emitter set holding 10 emitters."""
     n_emitters = 10
     return em.RandomEmitterSet(n_emitters)
Ejemplo n.º 11
0
    def test_iadd(self):
        """In-place ``+=`` concatenates: 20 + 50 emitters yield 70."""
        left, right = emitter.RandomEmitterSet(20), emitter.RandomEmitterSet(50)

        left += right

        assert len(left) == 70
Ejemplo n.º 12
0
 def em(self):
     """Provide a random emitter set of 10 emitters."""
     size = 10
     return emitter.RandomEmitterSet(size)
Ejemplo n.º 13
0
def em_rand():
    """Random set of 20 emitters in 'px' units with px_size (100, 200)."""
    pixel_size = (100, 200)
    return emitter.RandomEmitterSet(20, xy_unit='px', px_size=pixel_size)
Ejemplo n.º 14
0
class TestEmitterSet:
    """Tests for ``emitter.EmitterSet`` and related helpers."""

    def test_properties(self, em2d, em3d, em3d_full):
        """Accessing derived properties must not raise.

        Unit-dependent properties are only accessed when both ``px_size`` and
        ``xy_unit`` are set.
        """
        for em in (em2d, em3d, em3d_full):
            em.phot_scr
            em.bg_scr

            if em.px_size is not None and em.xy_unit is not None:
                em.xyz_px
                em.xyz_nm
                em.xyz_scr
                em.xyz_scr_px
                em.xyz_scr_nm
                em.xyz_sig_px
                em.xyz_sig_nm
                em.xyz_sig_tot_nm
                em.xyz_sig_weighted_tot_nm

        # ToDo: Test auto conversion

    def test_dim(self, em2d, em3d):
        """``dim()`` reports 2 for the 2D and 3 for the 3D fixture."""
        assert em2d.dim() == 2
        assert em3d.dim() == 3

    def test_xyz_shape(self, em2d, em3d):
        """
        Tests shape and correct data type
        Args:
            em2d: fixture (see above)
            em3d: fixture (see above)

        """
        # 2D input gets converted to 3D with zeros
        assert em2d.xyz.shape[1] == 3
        assert em3d.xyz.shape[1] == 3

        assert em3d.frame_ix.dtype in (torch.int, torch.long, torch.short)

    xyz_conversion_data = [  # xyz_input, xy_unit, px_size, expect px, expect nm
        (torch.empty((0, 3)), None, None, "err", "err"),
        (torch.empty((0, 3)), 'px', None, torch.empty((0, 3)), "err"),
        (torch.empty((0, 3)), 'nm', None, "err", torch.empty((0, 3))),
        (torch.tensor([[25., 25., 5.]]), None, None, "err", "err"),
        (torch.tensor([[25., 25., 5.]]), 'px', None,
         torch.tensor([[25., 25., 5.]]), "err"),
        (torch.tensor([[25., 25., 5.]]), 'nm', None,
         "err", torch.tensor([[25., 25., 5.]])),
        (torch.tensor([[.25, .25, 5.]]), 'px', (50., 100.),
         torch.tensor([[.25, .25, 5.]]), torch.tensor([[12.5, 25., 5.]])),
        (torch.tensor([[25., 25., 5.]]), 'nm', (50., 100.),
         torch.tensor([[.5, .25, 5.]]), torch.tensor([[25., 25., 5.]])),
    ]

    @pytest.mark.parametrize("xyz_input,xy_unit,px_size,expct_px,expct_nm",
                             xyz_conversion_data)
    @pytest.mark.filterwarnings("ignore::UserWarning")
    def test_xyz_conversion(self, xyz_input, xy_unit, px_size, expct_px,
                            expct_nm):
        """``xyz_px`` / ``xyz_nm`` convert correctly or raise ``ValueError``.

        UserWarnings emitted on construction are ignored. The filter spec is
        ``action::category`` -- a single colon would match on the message text.
        """
        em = emitter.CoordinateOnlyEmitter(xyz_input,
                                           xy_unit=xy_unit,
                                           px_size=px_size)

        if isinstance(expct_px, str) and expct_px == "err":
            with pytest.raises(ValueError):
                _ = em.xyz_px
        else:
            assert test_utils.tens_almeq(em.xyz_px, expct_px)

        if isinstance(expct_nm, str) and expct_nm == "err":
            with pytest.raises(ValueError):
                _ = em.xyz_nm
        else:
            assert test_utils.tens_almeq(em.xyz_nm, expct_nm)

    xyz_cr_conversion_data = [  # xyz_scr_input, xy_unit, px_size, expect_scr_px, expect_scr_nm
        (torch.empty((0, 3)), None, None, "err", "err"),
        (torch.empty((0, 3)), 'px', None, torch.empty((0, 3)), "err"),
        (torch.empty((0, 3)), 'nm', None, "err", torch.empty((0, 3))),
        (torch.tensor([[25., 25., 5.]]), None, None, "err", "err"),
        (torch.tensor([[25., 25., 5.]]), 'px', None,
         torch.tensor([[25., 25., 5.]]), "err"),
        (torch.tensor([[25., 25., 5.]]), 'nm', None,
         "err", torch.tensor([[25., 25., 5.]])),
        (torch.tensor([[.25, .25, 5.]]), 'px', (50., 100.),
         torch.tensor([[.25, .25, 5.]]), torch.tensor([[12.5, 25., 5.]])),
        (torch.tensor([[25., 25., 5.]]), 'nm', (50., 100.),
         torch.tensor([[.5, .25, 5.]]), torch.tensor([[25., 25., 5.]])),
    ]

    @pytest.mark.parametrize("xyz_scr_input,xy_unit,px_size,expct_px,expct_nm",
                             xyz_cr_conversion_data)
    @pytest.mark.filterwarnings("ignore::UserWarning")
    def test_xyz_cr_conversion(self, xyz_scr_input, xy_unit, px_size, expct_px,
                               expct_nm):
        """
        Here we test the Cramer-Rao unit conversion. The test data mirrors the
        xyz conversion data because the conversion logic is the same for this
        test candidate: ``xyz_cr`` is set to the squared input and the sqrt
        values (``xyz_scr_*``) are compared against the expectation.

        """
        em = emitter.CoordinateOnlyEmitter(torch.rand_like(xyz_scr_input),
                                           xy_unit=xy_unit,
                                           px_size=px_size)
        em.xyz_cr = xyz_scr_input**2

        # NOTE(review): the error branches probe xyz_cr_* while the success
        # branches check xyz_scr_*; confirm both raise in the same situations.
        if isinstance(expct_px, str) and expct_px == "err":
            with pytest.raises(ValueError):
                _ = em.xyz_cr_px
        else:
            assert test_utils.tens_almeq(em.xyz_scr_px, expct_px)

        if isinstance(expct_nm, str) and expct_nm == "err":
            with pytest.raises(ValueError):
                _ = em.xyz_cr_nm
        else:
            assert test_utils.tens_almeq(em.xyz_scr_nm, expct_nm)

    @pytest.mark.parametrize("attr,power", [('xyz', 1), ('xyz_sig', 1),
                                            ('xyz_cr', 2)])
    def test_property_conversion(self, attr, power, em3d_full):
        """``<attr>_nm`` delegates to ``_pxnm_conversion`` with the right power."""
        with mock.patch.object(emitter.EmitterSet,
                               '_pxnm_conversion') as conversion:
            getattr(em3d_full, attr + '_nm')

        conversion.assert_called_once_with(getattr(em3d_full, attr),
                                           in_unit='nm',
                                           tar_unit='nm',
                                           power=power)

    @mock.patch.object(emitter.EmitterSet, 'cat')
    def test_add(self, mock_add):
        """``a + b`` delegates to ``EmitterSet.cat((a, b), None, None)``."""
        em_0 = emitter.RandomEmitterSet(20)
        em_1 = emitter.RandomEmitterSet(100)

        _ = em_0 + em_1
        mock_add.assert_called_once_with((em_0, em_1), None, None)

    def test_iadd(self):
        """In-place ``+=`` concatenates: 20 + 50 emitters yield 70."""
        em_0 = emitter.RandomEmitterSet(20)
        em_1 = emitter.RandomEmitterSet(50)

        em_0 += em_1
        assert len(em_0) == 70

    def test_chunk(self):
        """Chunking then concatenating reproduces the original set."""
        big_em = RandomEmitterSet(100000)

        splits = big_em.chunks(10000)
        re_merged = EmitterSet.cat(splits)

        assert sum(len(e) for e in splits) == len(big_em)
        assert re_merged == big_em

        # not evenly splittable: 7 emitters, chunks(3) -> first 3, then 2, last 2
        em = RandomEmitterSet(7)
        splits = em.chunks(3)

        assert len(splits[0]) == 3
        assert len(splits[1]) == 2
        assert len(splits[-1]) == 2

    def test_split_in_frames(self, em2d, em3d):
        """``split_in_frames`` yields one set per frame in the requested range."""
        splits = em2d.split_in_frames(None, None)
        assert len(splits) == 1

        splits = em3d.split_in_frames(None, None)
        assert em3d.frame_ix.max() - em3d.frame_ix.min() + 1 == len(splits)

        # negative numbers in frame_ix
        neg_frames = EmitterSet(torch.rand((3, 3)), torch.rand(3),
                                torch.tensor([-1, 0, 1]))
        splits = neg_frames.split_in_frames(None, None)
        assert len(splits) == 3
        splits = neg_frames.split_in_frames(0, None)
        assert len(splits) == 2

    def test_adjacent_frame_split(self):
        """Each split only contains emitters of its own frame index."""
        xyz = torch.rand((500, 3))
        phot = torch.rand_like(xyz[:, 0])
        frame_ix = torch.randint_like(xyz[:, 0], low=-1, high=2).int()
        em = EmitterSet(xyz, phot, frame_ix)

        em_split = em.split_in_frames(-1, 1)
        assert (em_split[0].frame_ix == -1).all()
        assert (em_split[1].frame_ix == 0).all()
        assert (em_split[2].frame_ix == 1).all()

        em_split = em.split_in_frames(0, 0)
        assert len(em_split) == 1
        assert (em_split[0].frame_ix == 0).all()

        em_split = em.split_in_frames(-1, -1)
        assert len(em_split) == 1
        assert (em_split[0].frame_ix == -1).all()

        em_split = em.split_in_frames(1, 1)
        assert len(em_split) == 1
        assert (em_split[0].frame_ix == 1).all()

    def test_cat_emittersets(self):
        """``EmitterSet.cat`` with step/explicit frame offsets and meta merging."""
        sets = [RandomEmitterSet(50), RandomEmitterSet(20)]
        cat_sets = EmitterSet.cat(sets, None, 1)
        assert 70 == len(cat_sets)
        assert 0 == cat_sets.frame_ix[0]
        assert 1 == cat_sets.frame_ix[50]

        sets = [RandomEmitterSet(50), RandomEmitterSet(20)]
        cat_sets = EmitterSet.cat(sets, torch.tensor([5, 50]), None)
        assert 70 == len(cat_sets)
        assert 5 == cat_sets.frame_ix[0]
        assert 50 == cat_sets.frame_ix[50]

        # test correctness of px size and xy unit
        sets = [
            RandomEmitterSet(50, xy_unit='px', px_size=(100., 200.)),
            RandomEmitterSet(20)
        ]
        em = EmitterSet.cat(sets)
        assert em.xy_unit == 'px'
        assert (em.px_size == torch.tensor([100., 200.])).all()

    def test_split_cat(self):
        """
        Tests whether split and cat (and sort by ID) returns the same result as
        the original starting set.

        """
        em = RandomEmitterSet(1000)
        em.id = torch.arange(len(em))
        em.frame_ix = torch.randint_like(em.frame_ix, 10000)

        em_split = em.split_in_frames(0, 9999)
        em_re_merged = EmitterSet.cat(em_split)

        # sort both by id before comparing
        ix = torch.argsort(em.id)
        ix_re = torch.argsort(em_re_merged.id)

        assert em[ix] == em_re_merged[ix_re]

    @pytest.mark.parametrize("frac", [0., 0.1, 0.5, 0.9, 1.])
    def test_sigma_filter(self, frac):
        """``filter_by_sigma(fraction=frac)`` keeps ``frac`` of the emitters."""
        em = emitter.RandomEmitterSet(10000)
        em.xyz_sig = (torch.randn_like(em.xyz_sig) + 5).clamp(0.)

        out = em.filter_by_sigma(fraction=frac)

        assert len(em) * frac == pytest.approx(len(out))

    def test_hist_detection(self):
        """``hist_detection`` returns a mapping keyed by prob and per-axis sigma."""
        em = emitter.RandomEmitterSet(10000)
        em.prob = torch.rand_like(em.prob)
        em.xyz_sig = torch.randn_like(em.xyz_sig) * torch.tensor(
            [1., 2., 3.]).unsqueeze(0)

        out = em.hist_detection()

        assert set(out.keys()) == {'prob', 'sigma_x', 'sigma_y', 'sigma_z'}

    def test_sanity_check(self):
        """Malformed constructor input raises ``ValueError``."""
        # 1D tensors must be 1D: phot of shape (10, 1) is rejected
        xyz = torch.rand((10, 3))
        phot = torch.rand((10, 1))
        frame_ix = torch.rand(10)
        with pytest.raises(ValueError):
            EmitterSet(xyz, phot, frame_ix)

        # number of elements must agree: 11 phot vs 10 xyz
        xyz = torch.rand((10, 3))
        phot = torch.rand((11, 1))
        frame_ix = torch.rand(10)
        with pytest.raises(ValueError):
            EmitterSet(xyz, phot, frame_ix)

    @pytest.mark.parametrize("em", [
        emitter.RandomEmitterSet(25, 64, px_size=(100., 125.)),
        emitter.EmptyEmitterSet(xy_unit='nm', px_size=(100., 125.))
    ])
    def test_inplace_replace(self, em):
        """``_inplace_replace`` makes the target compare equal to ``em``."""
        em_start = emitter.RandomEmitterSet(25, xy_unit='px', px_size=None)
        em_start._inplace_replace(em)

        assert em_start == em

    @pytest.mark.parametrize("suffix", ['.pt', '.h5', '.csv'])
    @pytest.mark.filterwarnings(
        "ignore:.*For .csv files, implicit usage of .load()")
    def test_save_load(self, suffix, tmpdir):
        """A save/load round trip yields an equal EmitterSet for each format."""
        em = RandomEmitterSet(1000, xy_unit='nm', px_size=(100., 100.))

        p = Path(tmpdir / f'em{suffix}')
        em.save(p)
        em_load = EmitterSet.load(p)
        assert em == em_load, "Reloaded emitterset is not equivalent to inital one."

    @pytest.mark.parametrize(
        "em_a,em_b,expct",
        [(CoordinateOnlyEmitter(torch.tensor([[0., 1., 2.]])),
          CoordinateOnlyEmitter(torch.tensor([[0., 1., 2.]])), True),
         (CoordinateOnlyEmitter(torch.tensor([[0., 1., 2.]]), xy_unit='px'),
          CoordinateOnlyEmitter(torch.tensor([[0., 1., 2.]]),
                                xy_unit='nm'), False),
         (CoordinateOnlyEmitter(torch.tensor([[0., 1., 2.]]), xy_unit='px'),
          CoordinateOnlyEmitter(torch.tensor([[0., 1.1, 2.]]),
                                xy_unit='px'), False)])
    def test_eq(self, em_a, em_b, expct):
        """Equality depends on both coordinates and ``xy_unit``."""
        if expct:
            assert em_a == em_b
        else:
            assert not (em_a == em_b)

    def test_meta(self):
        """``meta`` exposes exactly ``xy_unit`` and ``px_size``."""
        em = RandomEmitterSet(100, xy_unit='nm', px_size=(100., 200.))
        assert set(em.meta.keys()) == {'xy_unit', 'px_size'}

    def test_data(self):
        return  # implicitly in test_to_dict

    def test_to_dict(self):
        """One round trip through ``to_dict`` and the constructor is lossless."""
        em = RandomEmitterSet(100, xy_unit='nm', px_size=(100., 200.))
        em_clone = em.clone()

        em_dict = EmitterSet(**em.to_dict())
        assert em_clone == em_dict
Ejemplo n.º 15
0
    def test_inplace_replace(self, em):
        """After ``_inplace_replace(em)`` the target must compare equal to ``em``."""
        target = emitter.RandomEmitterSet(25, xy_unit='px', px_size=None)
        target._inplace_replace(em)
        assert target == em
Ejemplo n.º 16
0
 def __getitem__(self, item):
     """Return a freshly generated random sample; ``item`` is ignored.

     The tuple holds a (1, 64, 64) tensor, two (6, 64, 64) tensors and a
     random emitter set of 32 emitters.
     """
     single = torch.rand((1, 64, 64))
     stack_a = torch.rand((6, 64, 64))
     stack_b = torch.rand((6, 64, 64))
     emitters = emitter.RandomEmitterSet(32)
     return single, stack_a, stack_b, emitters
Ejemplo n.º 17
0
 def dummy_sampler():
     """Return a fresh random emitter set with 20 emitters."""
     n_emitters = 20
     return emitter.RandomEmitterSet(n_emitters)
Ejemplo n.º 18
0
    def test_add(self, mock_add):
        """``a + b`` delegates to ``EmitterSet.cat((a, b), None, None)``.

        Args:
            mock_add: mock standing in for ``EmitterSet.cat`` (presumably injected
                by a ``mock.patch.object`` decorator not visible here -- confirm)
        """
        left = emitter.RandomEmitterSet(20)
        right = emitter.RandomEmitterSet(100)

        _ = left + right

        mock_add.assert_called_once_with((left, right), None, None)