def test_print_histogram(self): def f(x, y): s = jit(jnp.sin)(x) return jnp.sin(s) + jnp.cos(y) hist = jaxpr_util.primitives_by_source(make_jaxpr(f)(1., 1.).jaxpr) jaxpr_util.print_histogram(hist)
def test_primitives_by_source(self): def f(x, y): s = jnp.sin(x) return jnp.sin(s) + jnp.cos(y) hist = jaxpr_util.primitives_by_source(make_jaxpr(f)(1., 1.).jaxpr) sin_keys = [k for k in hist.keys() if k.startswith('sin @ ')] self.assertEqual(len(sin_keys), 2) self.assertTrue(all(count == 1 for count in hist.values()))
def test_primitives_by_source(self): def f(x, y): s = jnp.sin(x) return jnp.sin(s) + jnp.cos(y) hist = jaxpr_util.primitives_by_source(make_jaxpr(f)(1., 1.).jaxpr) sin_keys = [k for k in hist.keys() if k.startswith('sin @ ')] rem_keys = [k for k in hist.keys() if not k.startswith('sin @ ')] self.assertEqual(sum(hist[k] for k in sin_keys), 2) self.assertTrue(all(hist[k] == 1 for k in rem_keys))