Esempio n. 1
0
    def test_print_histogram(self):
        def f(x, y):
            s = jit(jnp.sin)(x)
            return jnp.sin(s) + jnp.cos(y)

        hist = jaxpr_util.primitives_by_source(make_jaxpr(f)(1., 1.).jaxpr)
        jaxpr_util.print_histogram(hist)
Esempio n. 2
0
    def test_primitives_by_source(self):
        def f(x, y):
            s = jnp.sin(x)
            return jnp.sin(s) + jnp.cos(y)

        hist = jaxpr_util.primitives_by_source(make_jaxpr(f)(1., 1.).jaxpr)

        sin_keys = [k for k in hist.keys() if k.startswith('sin @ ')]
        self.assertEqual(len(sin_keys), 2)
        self.assertTrue(all(count == 1 for count in hist.values()))
Esempio n. 3
0
    def test_primitives_by_source(self):
        def f(x, y):
            s = jnp.sin(x)
            return jnp.sin(s) + jnp.cos(y)

        hist = jaxpr_util.primitives_by_source(make_jaxpr(f)(1., 1.).jaxpr)

        sin_keys = [k for k in hist.keys() if k.startswith('sin @ ')]
        rem_keys = [k for k in hist.keys() if not k.startswith('sin @ ')]

        self.assertEqual(sum(hist[k] for k in sin_keys), 2)
        self.assertTrue(all(hist[k] == 1 for k in rem_keys))