def test_duplicate_names_error_in_strict_mode(self):
    """Summarizing under the same name twice (no `mode=`) raises ValueError."""

    def summarize_twice(value):
        # Both summaries use name 'x' and no `mode=` argument; per this test's
        # name that is strict mode, where the second one must be rejected.
        for _ in range(2):
            summary.summary(value, name='x')
        return value

    with self.assertRaisesRegex(ValueError, 'has already been reaped: x'):
        summary.get_summaries(summarize_twice)(2.)
def test_can_pull_out_non_dependent_values(self):
    """A summarized value is reaped even when the output does not use it."""

    def square_then_identity(inp):
        # `inp ** 2` is only summarized; the function returns `inp` unchanged.
        summary.summary(inp**2, name='y')
        return inp

    _output, reaped = summary.get_summaries(square_then_identity)(2.)
    self.assertDictEqual({'y': 4.}, reaped)
def test_can_append_to_growing_list_with_summary(self):
    """Append-mode summaries under one name collect into a single array."""

    def append_two(inp):
        # Emits inp + 1. then inp + 2., both under the name 'x'.
        for offset in (1., 2.):
            summary.summary(inp + offset, name='x', mode='append')
        return inp

    _output, reaped = summary.get_summaries(append_two)(2.)
    self.assertSetEqual(set(reaped), {'x'})
    expected = np.array([3., 4.])
    np.testing.assert_allclose(reaped['x'], expected)
def test_can_pull_summaries_out_of_scan_in_append_mode(self):
    """Append-mode summaries inside `lax.scan` yield one entry per step."""

    def count_up(start):
        def step(carry, _unused):
            # Record the carry before incrementing it.
            summary.summary(carry, name='x', mode='append')
            return carry + 1, ()

        # Scan over 10 elements; only the final carry is returned.
        final_carry, _ = lax.scan(step, start, jnp.arange(10.))
        return final_carry

    value, reaped = summary.get_summaries(count_up)(0.)
    self.assertEqual(value, 10.)
    np.testing.assert_allclose(reaped['x'], np.arange(10.))
def test_can_pull_out_summarized_values_in_strict_mode(self):
    """A value passed through `summary.summary` (no `mode=`) is reaped."""

    def summarize(inp):
        # The function's output is whatever summary.summary returns.
        return summary.summary(inp, name='x')

    outputs = summary.get_summaries(summarize)(1.)
    _, reaped = outputs
    self.assertDictEqual({'x': 1.}, reaped)