Пример #1
0
 def test_single_summary(self):
     """Averaging one summary keeps its history values and uses the given name."""
     s1 = Summary('ex2')
     s1.history['train']['acc'][10] = 0.55
     result = average_summaries(name="ex1", summaries=[s1])
     with self.subTest("Should return a summary instance"):
         # assertIsInstance reports the actual type on failure, unlike
         # assertTrue(isinstance(...)) which only says "False is not true".
         self.assertIsInstance(result, Summary)
     with self.subTest("Name should be correct"):
         # The input is deliberately named 'ex2' so this checks that the
         # result takes the name passed to average_summaries.
         self.assertEqual("ex1", result.name)
     with self.subTest("History should be correct"):
         self.assertEqual(0.55, result.history['train']['acc'][10])
Пример #2
0
    def test_multi_summaries(self):
        """Averaging two summaries merges their histories point by point.

        Covers keys present in both inputs (combined into a mean with
        lower/upper bounds), keys present in only one input, a ValWithError
        input value, and string values that carry a leading number.

        The expected bounds correspond to mean -/+ the sample standard
        deviation (ddof=1): e.g. {0.2, 0.3} -> 0.25 -/+ 0.0707 and
        {4, 3} -> 3.5 - 0.7071.
        """
        s1 = Summary('s1')
        s2 = Summary('s2')

        s1.history['train']['acc'] = {
            0: 0.2,
            10: 0.4,
            20: 0.7,
            45: ValWithError(0.8, 0.9, 1.0)
        }
        s2.history['train']['acc'] = {0: 0.3, 10: 0.4, 20: 0.9, 30: 1.0}

        # Metric present in only one summary.
        s1.history['test']['mcc'] = {40: 0.834}

        # Mixed numeric and string values; the strings' leading numbers are
        # expected to be used in the average (4 & '3 wombats' -> 3.5).
        s1.history['eval']['wombats'] = {5: 4, 10: 9}
        s2.history['eval']['wombats'] = {5: '3 wombats', 10: '7 wombats'}

        s_merge = average_summaries('s', [s1, s2])
        train_acc = s_merge.history['train']['acc']

        with self.subTest("Should return a summary instance"):
            self.assertIsInstance(s_merge, Summary)
        with self.subTest("Name should be correct"):
            self.assertEqual("s", s_merge.name)
        with self.subTest("Paired datapoints should be matched"):
            # Compare as sets so a failure prints a readable set difference.
            self.assertEqual({0, 10, 20, 30, 45}, set(train_acc))
        with self.subTest(
                "Paired datapoints should have correct mean and std"):
            # Computed floats: compare approximately rather than exactly so
            # the test does not depend on the summation order used.
            self.assertAlmostEqual(0.25, train_acc[0].y)
            self.assertAlmostEqual(0.17929, train_acc[0].y_min, places=5)
            self.assertAlmostEqual(0.32071, train_acc[0].y_max, places=5)
        with self.subTest(
                "Paired datapoints with zero std should be handled correctly"):
            self.assertAlmostEqual(0.4, train_acc[10].y)
            self.assertAlmostEqual(0.4, train_acc[10].y_min, places=5)
            self.assertAlmostEqual(0.4, train_acc[10].y_max, places=5)
        with self.subTest(
                "Partially paired datapoints should be handled correctly"):
            # Key 45 exists only in s1, as ValWithError(0.8, 0.9, 1.0).
            # NOTE(review): which ValWithError field 0.9 corresponds to
            # depends on ValWithError's definition (not visible here).
            self.assertEqual(0.9, train_acc[45])
            # Key 30 exists only in s2 and passes through unchanged.
            self.assertEqual(1.0, train_acc[30])
        with self.subTest("Unpaired datapoints should be handled correctly"):
            self.assertEqual(0.834, s_merge.history['test']['mcc'][40])
        with self.subTest("String values should be handled correctly"):
            wombats = s_merge.history['eval']['wombats']
            self.assertAlmostEqual(3.5, wombats[5].y)
            self.assertAlmostEqual(2.79289, wombats[5].y_min, places=5)
            self.assertAlmostEqual(8.0, wombats[10].y)
Пример #3
0
 def test_empty_summaries(self):
     """Averaging an empty list of summaries still returns a named Summary."""
     result = average_summaries(name="ex1", summaries=[])
     with self.subTest("Should return a summary instance"):
         # assertIsInstance reports the actual type on failure.
         self.assertIsInstance(result, Summary)
     with self.subTest("Name should be correct"):
         self.assertEqual("ex1", result.name)