def test_primitives_by_shape(self):
  """Checks jaxpr_util.primitives_by_shape on a small traced function.

  The traced function uses add, sin, cos, reduce_sum, concatenate and a
  jit-ed sub-call (keyed as xla_call). The test asserts that each expected
  "<primitive> :: <dtype>[<shape>]" key appears exactly once in the histogram.
  """

  def f(x, y):
    def sub(x, y):
      return jnp.sum(jnp.array([x, y])), y
    s, _ = jit(sub)(x, y)
    return jnp.sin(s) + jnp.cos(y)

  hist = jaxpr_util.primitives_by_shape(make_jaxpr(f)(1., 1.).jaxpr)
  # Python float inputs trace as float64 when x64 mode is enabled and as
  # float32 otherwise. Derive the dtype name instead of hard-coding
  # 'float32' so the expected keys match in both modes.
  ftype = jnp.array(1.).dtype.name
  shapes = [
      f'add :: {ftype}[]',
      f'sin :: {ftype}[]',
      f'cos :: {ftype}[]',
      f'reduce_sum :: {ftype}[]',
      f'concatenate :: {ftype}[2]',
      f'xla_call :: {ftype}[] *',
  ]
  for k in shapes:
    self.assertEqual(hist[k], 1)
def test_primitives_by_shape(self):
  """Verifies the per-(primitive, shape) histogram of a small jaxpr.

  Traces a function built from add, sin, cos, reduce_sum, concatenate and a
  nested jit call. The test expects each corresponding histogram key to have
  a count of exactly 1. The float width in the keys follows the
  jax_enable_x64 flag.
  """

  def f(x, y):
    def sub(a, b):
      return jnp.sum(jnp.array([a, b])), b
    total, _ = jit(sub)(x, y)
    return jnp.sin(total) + jnp.cos(y)

  closed_jaxpr = make_jaxpr(f)(1., 1.)
  hist = jaxpr_util.primitives_by_shape(closed_jaxpr.jaxpr)

  # Expected dtype prefix depends on whether 64-bit mode is active.
  if FLAGS.jax_enable_x64:
    float_name = 'float64'
  else:
    float_name = 'float32'

  expected_keys = [
      f'{prim} :: {float_name}{suffix}'
      for prim, suffix in (
          ('add', '[]'),
          ('sin', '[]'),
          ('cos', '[]'),
          ('reduce_sum', '[]'),
          ('concatenate', '[2]'),
          ('xla_call', '[] *'),
      )
  ]
  for key in expected_keys:
    self.assertEqual(hist[key], 1)