示例#1
0
 def test_duplicate_names_error_in_strict_mode(self):
   """Tagging two summaries with the same name (no append mode) must fail.

   No `mode` argument is passed to `summary.summary`; per this test's name
   that exercises strict mode, where the second `'x'` summary is rejected
   with a ValueError.
   """

   def tag_twice(value):
     # Emit the same summary name twice in one traced call.
     for _ in range(2):
       summary.summary(value, name='x')
     return value

   expected_error = 'has already been reaped: x'
   with self.assertRaisesRegex(ValueError, expected_error):
     summary.get_summaries(tag_twice)(2.)
示例#2
0
    def test_can_pull_out_non_dependent_values(self):
        """A summarized value that never reaches the output is still reaped."""

        def identity_with_side_summary(x):
            # x**2 is only summarized; the function returns x unchanged.
            summary.summary(x**2, name='y')
            return x

        _unused_output, reaped = summary.get_summaries(
            identity_with_side_summary)(2.)
        self.assertDictEqual({'y': 4.}, reaped)
示例#3
0
 def test_can_append_to_growing_list_with_summary(self):
   """Append-mode summaries sharing one name are collected in call order."""

   def f(x):
     # Two append-mode summaries under 'x': x + 1. then x + 2.
     for offset in (1., 2.):
       summary.summary(x + offset, name='x', mode='append')
     return x

   _unused_output, reaped = summary.get_summaries(f)(2.)
   self.assertSetEqual(set(reaped), {'x'})
   np.testing.assert_allclose(reaped['x'], np.array([3., 4.]))
示例#4
0
 def test_can_pull_summaries_out_of_scan_in_append_mode(self):
   """Append-mode summaries emitted inside `lax.scan` are reaped per step."""

   def f(x):
     def step(carry, _):
       # Record the carry at every scan iteration before incrementing it.
       summary.summary(carry, name='x', mode='append')
       return carry + 1, ()

     scanned_over = jnp.arange(10.)
     final_carry, _ = lax.scan(step, x, scanned_over)
     return final_carry

   final_value, reaped = summary.get_summaries(f)(0.)
   self.assertEqual(final_value, 10.)
   np.testing.assert_allclose(reaped['x'], np.arange(10.))
示例#5
0
    def test_can_pull_out_summarized_values_in_strict_mode(self):
        """A summary whose result is also returned by the function is reaped."""

        def f(x):
            # The value produced by summary.summary is the function's output.
            tagged = summary.summary(x, name='x')
            return tagged

        _unused_output, reaped = summary.get_summaries(f)(1.)
        self.assertDictEqual({'x': 1.}, reaped)