예제 #1
0
    def test_get_full_name(self):
        """Check get_full_name() for one named and two unnamed recoil elements.

        The fixture element is expected to read "He-bar". A 16O element
        whose name is None or the empty string is expected to fall back
        to "16O-Default".
        """
        expected_named = "He-bar"
        self.assertEqual(expected_named, self.rec_elem.get_full_name())

        # None and "" must both be treated as "no name given".
        for blank in (None, ""):
            unnamed = RecoilElement(
                Element.from_string("16O"), [], name=blank)
            self.assertEqual("16O-Default", unnamed.get_full_name())
예제 #2
0
class TestRecoilElement(unittest.TestCase):
    def setUp(self):
        self.timestamp = time.time()
        self.rec_type = "rec"
        self.ch_width = 4
        self.rec_elem = RecoilElement(
            mo.get_element(),
            [Point((0, 4)),
             Point((1, 5)),
             Point((2, 10))],
            color="black",
            description="foo",
            name="bar",
            rec_type="rec",
            reference_density=3,
            channel_width=self.ch_width,
            modification_time=self.timestamp
        )

    def test_get_full_name(self):
        self.assertEqual("He-bar", self.rec_elem.get_full_name())

        rec_elem = RecoilElement(Element.from_string("16O"), [], name=None)
        self.assertEqual("16O-Default", rec_elem.get_full_name())

        rec_elem = RecoilElement(Element.from_string("16O"), [], name="")
        self.assertEqual("16O-Default", rec_elem.get_full_name())

    def test_serialization(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            file_path = fp.get_recoil_file_path(self.rec_elem, tmp_dir)
            self.rec_elem.to_file(tmp_dir)

            rec_elem2 = RecoilElement.from_file(file_path,
                                                channel_width=self.ch_width,
                                                rec_type=self.rec_type)

            self.compare_rec_elems(self.rec_elem, rec_elem2)

            # Test with an empty list of points and no rec_type or ch_width
            rec_elem3 = RecoilElement(Element.from_string("O"), [])
            file_path = fp.get_recoil_file_path(rec_elem3, tmp_dir)
            rec_elem3.to_file(tmp_dir)
            rec_elem4 = RecoilElement.from_file(file_path)

            self.compare_rec_elems(rec_elem3, rec_elem4)

        self.assertFalse(os.path.exists(tmp_dir))

    def compare_rec_elems(self, rec_elem1, rec_elem2):
        fst = dict(vars(rec_elem1))
        snd = dict(vars(rec_elem2))

        self.assertEqual(fst.pop("_points"), snd.pop("_points"))
        self.assertEqual(fst.pop("element"), snd.pop("element"))

        times = fst.pop("modification_time"), snd.pop("modification_time")

        if None not in times:
            self.assertAlmostEqual(times[0], times[1], places=2)

        self.assertEqual(fst, snd)

    def test_calculate_area(self):
        self.assertEqual(12, self.rec_elem.calculate_area())
        self.assertEqual(4.5, self.rec_elem.calculate_area(
            start=0, end=1))

        self.assertEqual(0, self.rec_elem.calculate_area(
            start=0.5, end=0.5))
        self.assertEqual(2.25, self.rec_elem.calculate_area(
            start=0.25, end=0.75))

        self.assertEqual(5.5, self.rec_elem.calculate_area(
            start=0.5, end=1.5))

        # If the interval is outside the point range, 0 is returned
        self.assertEqual(0, self.rec_elem.calculate_area(
            start=2, end=3))
        self.assertEqual(0, self.rec_elem.calculate_area(
            start=-2, end=0))

        # If the length of the interval is non-positive, 0 is returned
        self.assertEqual(0, self.rec_elem.calculate_area(
            start=1, end=1))
        self.assertEqual(0, self.rec_elem.calculate_area(
            start=1, end=0))

    def test_sorting(self):
        # Checks that recoil elements are sorted in the same way as elements
        n = 10
        iterations = 10
        for _ in range(iterations):
            elems = [mo.get_element(randomize=True) for _ in range(n)]
            rec_elems = [RecoilElement(elem, []) for elem in elems]
            random.shuffle(elems)
            random.shuffle(rec_elems)

            elems.sort()
            rec_elems.sort()
            for e, r in zip(elems, rec_elems):
                self.assertEqual(e, r.element)

    def test_identities(self):
        rec_elem1 = mo.get_recoil_element()
        rec_elem2 = mo.get_recoil_element()

        self.assertNotEqual(rec_elem1, rec_elem2)
        self.assertEqual(rec_elem1, rec_elem1)

        self.assertIs(rec_elem2, rec_elem2)