コード例 #1
0
    def test_to_dict(self):
        """Check that one round trip through ``to_dict`` and back preserves the set.

        The comparison is made against a clone taken before ``to_dict`` is
        called (presumably so that any mutation of ``em`` during serialisation
        would also be caught).
        """
        em = RandomEmitterSet(100, xy_unit='nm', px_size=(100., 200.))
        em_clone = em.clone()

        # Rebuild an EmitterSet from the keyword dict produced by to_dict().
        em_roundtrip = EmitterSet(**em.to_dict())
        assert em_clone == em_roundtrip
コード例 #2
0
    def test_save_load(self, format, tmpdir):
        """Saving an EmitterSet to disk and loading it back yields an equal set.

        ``format`` is appended to the file name (presumably a file extension
        supplied by a pytest parametrisation — TODO confirm); it keeps its name
        because pytest injects arguments by name. ``tmpdir`` is presumably
        pytest's temporary-directory fixture.
        """
        em = RandomEmitterSet(1000, xy_unit='nm', px_size=(100., 100.))

        # Wrap in pathlib.Path so save/load receive a standard path object.
        p = Path(tmpdir / f'em{format}')
        em.save(p)
        em_load = EmitterSet.load(p)
        assert em == em_load, (
            f"Reloaded emitterset (format {format!r}) is not equivalent to the initial one."
        )
コード例 #3
0
    def test_split_cat(self):
        """Splitting by frame and concatenating again reproduces the original set.

        Both sets are sorted by emitter ID before comparing, presumably because
        grouping emitters by frame need not preserve their original order.
        """
        n_frames = 10000

        em = RandomEmitterSet(1000)
        em.id = torch.arange(len(em))
        # randint_like's upper bound is exclusive, so frames lie in [0, n_frames - 1].
        em.frame_ix = torch.randint_like(em.frame_ix, n_frames)

        # Run: split over the full frame range, then merge the pieces again.
        em_split = em.split_in_frames(0, n_frames - 1)
        em_re_merged = EmitterSet.cat(em_split)

        # Assertions: sort both by id so that ordering differences are ignored.
        ix = torch.argsort(em.id)
        ix_re = torch.argsort(em_re_merged.id)

        assert em[ix] == em_re_merged[ix_re]
コード例 #4
0
def test_remove_out_of_field():
    """RemoveOutOfFOV keeps exactly the emitters inside the field of view.

    Bounds are treated as half-open per axis: lower inclusive, upper exclusive.
    """
    # Setup: spread z over [-750, 750) so some emitters fall outside the z range.
    em = RandomEmitterSet(100000, extent=100)
    em.xyz[:, 2] = torch.rand_like(em.xyz[:, 2]) * 1500 - 750

    # Per-axis (lower, upper) bounds for x, y and z.
    extents = ((0., 31.), (7.5, 31.5), (-500, 700))

    # Expected survivors, computed before running the candidate.
    in_fov = torch.ones_like(em.xyz[:, 0], dtype=torch.bool)
    for axis, (lower, upper) in enumerate(extents):
        in_fov &= (em.xyz[:, axis] >= lower) & (em.xyz[:, axis] < upper)
    n_expected = int(in_fov.sum())
    assert n_expected > 0, "setup should place some emitters inside the FOV"

    # Candidate
    rmf = prep.RemoveOutOfFOV(*extents)

    # Run and Test
    em_out = rmf.forward(em)

    assert len(em_out) <= len(em)

    # Every kept emitter lies inside the FOV on every axis ...
    for axis, (lower, upper) in enumerate(extents):
        assert (em_out.xyz[:, axis] >= lower).all()
        assert (em_out.xyz[:, axis] < upper).all()

    # ... and no emitter inside the FOV was dropped.
    assert len(em_out) == n_expected
コード例 #5
0
    def test_chunk(self):
        """``chunks(k)`` splits a set into ``k`` parts that concatenate back to it.

        When the length is not evenly divisible, the leading chunks take the
        extra elements (7 into 3 gives sizes 3, 2, 2).
        """
        big_em = RandomEmitterSet(100000)

        splits = big_em.chunks(10000)
        re_merged = EmitterSet.cat(splits)

        assert len(splits) == 10000
        assert sum(len(e) for e in splits) == len(big_em)
        assert re_merged == big_em

        # test not evenly splittable number
        em = RandomEmitterSet(7)
        splits = em.chunks(3)

        # Pin every chunk size (and thereby the chunk count), not just a few indices.
        assert [len(e) for e in splits] == [3, 2, 2]
        assert EmitterSet.cat(splits) == em
コード例 #6
0
    def test_cat_emittersets(self):
        """EmitterSet.cat concatenates sets, offsets frame indices and keeps metadata.

        The frame-offset checks assume every emitter of a fresh RandomEmitterSet
        sits on frame 0 (implied by the original per-element asserts — TODO
        confirm against RandomEmitterSet).
        """
        # Positional args (None, 1): constant frame step of 1 per set, so the
        # i-th set is shifted by i frames.
        sets = [RandomEmitterSet(50), RandomEmitterSet(20)]
        cat_sets = EmitterSet.cat(sets, None, 1)
        assert len(cat_sets) == 70
        # Check whole segments, not just their first element.
        assert (cat_sets.frame_ix[:50] == 0).all()
        assert (cat_sets.frame_ix[50:] == 1).all()

        # Positional args (tensor, None): explicit per-set frame offsets.
        sets = [RandomEmitterSet(50), RandomEmitterSet(20)]
        cat_sets = EmitterSet.cat(sets, torch.tensor([5, 50]), None)
        assert len(cat_sets) == 70
        assert (cat_sets.frame_ix[:50] == 5).all()
        assert (cat_sets.frame_ix[50:] == 50).all()

        # test correctness of px size and xy unit: the first set carries the
        # metadata and the concatenation must expose it.
        sets = [
            RandomEmitterSet(50, xy_unit='px', px_size=(100., 200.)),
            RandomEmitterSet(20)
        ]
        em = EmitterSet.cat(sets)
        assert em.xy_unit == 'px'
        assert (em.px_size == torch.tensor([100., 200.])).all()
コード例 #7
0
    def test_meta(self):
        """The ``meta`` mapping exposes exactly the ``xy_unit`` and ``px_size`` keys."""
        expected_keys = {'xy_unit', 'px_size'}

        em = RandomEmitterSet(100, xy_unit='nm', px_size=(100., 200.))
        meta_keys = set(em.meta.keys())

        assert meta_keys == expected_keys