示例#1
0
 def test_scalar_collect_batch_weight_broadcast(self):
     """Collect a batch of four scalars with a single scalar weight."""
     stats = StatisticsCollector()
     stats.collect([2, 1, 7, 6], weight=1.)

     self.assertTrue(stats.has_value)
     # Every element of the batch counts as one sample.
     self.assertEqual(stats.counter, 4)
     # mean = (2 + 1 + 7 + 6) / 4; square = (4 + 1 + 49 + 36) / 4
     self.assertAlmostEqual(stats.mean, 4.0)
     self.assertAlmostEqual(stats.square, 22.5)
     self.assertAlmostEqual(stats.weight_sum, 4.0)
示例#2
0
 def test_collect_empty_weight(self):
     """Collect four scalars while passing an empty list as the weight.

     The expected statistics are the same as for uniform unit weights.
     """
     stats = StatisticsCollector()
     stats.collect([2, 1, 7, 6], weight=[])

     self.assertTrue(stats.has_value)
     self.assertEqual(stats.counter, 4)
     # mean = 16 / 4; square = 90 / 4
     self.assertAlmostEqual(stats.mean, 4.0)
     self.assertAlmostEqual(stats.square, 22.5)
     self.assertAlmostEqual(stats.weight_sum, 4.0)
示例#3
0
 def test_scalar_collect_batch(self):
     """Collect a batch of three scalars with one weight per element."""
     values = [2, 1, 7]
     weights = [1, 3, 6]
     stats = StatisticsCollector()
     stats.collect(values, weight=weights)

     self.assertTrue(stats.has_value)
     self.assertEqual(stats.counter, 3)
     # mean = (2*1 + 1*3 + 7*6) / 10; square = (4*1 + 1*3 + 49*6) / 10
     self.assertAlmostEqual(stats.mean, 4.7)
     self.assertAlmostEqual(stats.square, 30.1)
     self.assertAlmostEqual(stats.weight_sum, 10.0)
示例#4
0
 def test_vector_collect(self):
     """Collect a single (3, 2) array into a (3, 2)-shaped collector."""
     sample = np.arange(6).reshape((3, 2))
     stats = StatisticsCollector(shape=(3, 2))
     stats.collect(sample)

     self.assertTrue(stats.has_value)
     self.assertEqual(stats.counter, 1)
     # With one sample, the mean is the sample and the square is its
     # element-wise square.
     np.testing.assert_almost_equal(stats.mean, sample)
     np.testing.assert_almost_equal(stats.square, sample * sample)
     self.assertAlmostEqual(stats.weight_sum, 1.0)
示例#5
0
 def test_collect_empty(self):
     """Collecting an empty list must leave the collector without a value."""
     stats = StatisticsCollector()
     stats.collect([])

     self.assertEqual(stats.shape, ())
     self.assertFalse(stats.has_value)
     self.assertEqual(stats.counter, 0)
     # All accumulated statistics are expected to stay at zero.
     self.assertAlmostEqual(stats.mean, 0.0)
     self.assertAlmostEqual(stats.square, 0.0)
     self.assertEqual(stats.weight_sum, 0.0)
示例#6
0
 def test_vector_collect_batch_weight_broadcast(self):
     """Collect a batch of four (3, 2) arrays with a single scalar weight."""
     batch = np.arange(24).reshape((4, 3, 2))
     # Unweighted averages over the leading (batch) axis.
     expected_mean = [[9, 10], [11, 12], [13, 14]]
     expected_square = [[126., 145.], [166., 189.], [214., 241.]]

     stats = StatisticsCollector(shape=(3, 2))
     stats.collect(batch, weight=1.)

     self.assertTrue(stats.has_value)
     self.assertEqual(stats.counter, 4)
     np.testing.assert_almost_equal(stats.mean, expected_mean)
     np.testing.assert_almost_equal(stats.square, expected_square)
     self.assertAlmostEqual(stats.weight_sum, 4.0)
示例#7
0
 def test_vector_collect_batch(self):
     """Collect a batch of four (3, 2) arrays with one weight per array."""
     batch = np.arange(24).reshape((4, 3, 2))
     weights = [1, 2, 3, 4]
     # Weighted averages over the leading (batch) axis; weights sum to 10.
     expected_mean = [[12, 13], [14, 15], [16, 17]]
     expected_square = [[180., 205.], [232., 261.], [292., 325.]]

     stats = StatisticsCollector(shape=(3, 2))
     stats.collect(batch, weight=weights)

     self.assertTrue(stats.has_value)
     self.assertEqual(stats.counter, 4)
     np.testing.assert_almost_equal(stats.mean, expected_mean)
     np.testing.assert_almost_equal(stats.square, expected_square)
     self.assertAlmostEqual(stats.weight_sum, 10.0)
示例#8
0
 def test_empty_vector(self):
     """A fresh (3, 2)-shaped collector reports zeros and no value."""
     stats = StatisticsCollector(shape=(3, 2))
     zeros = np.zeros((3, 2))

     self.assertEqual(stats.shape, (3, 2))
     self.assertFalse(stats.has_value)
     self.assertEqual(stats.counter, 0)
     # Nothing has been collected, so both statistics are all-zero arrays.
     np.testing.assert_almost_equal(stats.mean, zeros)
     np.testing.assert_almost_equal(stats.square, zeros)
     self.assertEqual(stats.weight_sum, 0.0)
示例#9
0
 def test_scalar_multi_collect(self):
     """Collect three scalars one call at a time, two with explicit weights."""
     stats = StatisticsCollector()
     # The first sample is collected without passing a weight.
     stats.collect(2)
     for value, w in ((1, 3), (7, 6)):
         stats.collect(value, weight=w)

     self.assertTrue(stats.has_value)
     self.assertEqual(stats.counter, 3)
     # Expected values match a weighted average with weights 1, 3 and 6.
     self.assertAlmostEqual(stats.mean, 4.7)
     self.assertAlmostEqual(stats.square, 30.1)
     self.assertAlmostEqual(stats.weight_sum, 10.0)
示例#10
0
 def test_reset(self):
     """After ``reset()`` every statistic of the collector is zero again."""
     stats = StatisticsCollector()
     stats.collect([2, 1, 7, 6])
     stats.reset()

     self.assertFalse(stats.has_value)
     self.assertEqual(stats.counter, 0)
     # None of the previously collected samples may leak into the results.
     self.assertAlmostEqual(stats.mean, 0.0)
     self.assertAlmostEqual(stats.square, 0.0)
     self.assertAlmostEqual(stats.var, 0.0)
     self.assertAlmostEqual(stats.stddev, 0.0)
     self.assertAlmostEqual(stats.weight_sum, 0.0)
示例#11
0
    def metric_collector(self, metric_name):
        """
        Yield a :class:`~tfsnippet.utils.StatisticsCollector` for one metric.

        Once the generator is resumed after the ``yield`` (presumably when
        the surrounding ``with`` block exits; the decorator is outside this
        view), the mean of the collected values is passed to
        :meth:`collect_metrics` under `metric_name`.  Nothing is reported
        if no value was collected, and statistics other than the mean are
        not forwarded.

        Args:
            metric_name (str): The name of this metric.

        Yields:
            StatisticsCollector: The collector for metric values.
        """
        self._require_context()
        collector = StatisticsCollector()
        yield collector
        if not collector.has_value:
            # Nothing was collected, so there is no mean to report.
            return
        self.collect_metrics(metrics={metric_name: collector.mean})
示例#12
0
 def test_vector_multi_collect(self):
     """Collect four (3, 2) arrays one call at a time with growing weights."""
     batch = np.arange(24).reshape((4, 3, 2))
     stats = StatisticsCollector(shape=(3, 2))
     # The first array is collected without passing a weight.
     stats.collect(batch[0])
     for sample, w in zip(batch[1:], (2, 3, 4)):
         stats.collect(sample, weight=w)

     self.assertTrue(stats.has_value)
     self.assertEqual(stats.counter, 4)
     np.testing.assert_almost_equal(
         stats.mean, [[12, 13], [14, 15], [16, 17]])
     np.testing.assert_almost_equal(
         stats.square, [[180., 205.], [232., 261.], [292., 325.]])

     # Variance is checked against E[x^2] - E[x]^2, clipped at zero, and
     # the standard deviation against its square root.
     expected_var = np.maximum(stats.square - stats.mean**2, 0.)
     np.testing.assert_almost_equal(stats.var, expected_var)
     np.testing.assert_almost_equal(stats.stddev, np.sqrt(expected_var))
     self.assertAlmostEqual(stats.weight_sum, 10.0)
示例#13
0
 def test_var_std(self):
     """Check ``var`` and ``stddev`` for a batch of four scalars."""
     stats = StatisticsCollector()
     stats.collect([2, 1, 7, 6])

     self.assertEqual(stats.counter, 4)
     # var = 22.5 - 4**2 = 6.5; stddev = sqrt(6.5)
     self.assertAlmostEqual(stats.var, 6.5)
     self.assertAlmostEqual(stats.stddev, 2.549509756796)
示例#14
0
 def test_shape_mismatch(self):
     """Collecting a (3,)-shaped value into a (3, 2) collector must fail."""
     stats = StatisticsCollector(shape=(3, 2))
     # Regex the ValueError message is required to match.
     expected_message = r'Shape mismatch: \(3,\) not ending with \(3, 2\)'
     with pytest.raises(ValueError, match=expected_message):
         stats.collect([1, 2, 3])