def test_jvp_should_transform_stacks(self): def f(x): with jax.named_scope('bar'): with jax.named_scope('baz'): return jnp.square(x) g = jax.named_scope('foo')(lambda x, t: jax.jvp(f, (x, ), (t, ))) jaxpr = jax.make_jaxpr(g)(1., 1.).jaxpr self.assertEqual(str(jaxpr.eqns[0].source_info.name_stack), 'foo/jvp(bar)/baz')
def test_jvp_should_apply_to_call_jaxpr(self):
  """Checks name stacks when ``jax.jvp`` is applied to a jitted function.

  The outer equation (the call to the jitted ``f``) should carry only
  ``foo``; the first equation inside its ``call_jaxpr`` should carry
  ``bar/baz``; and the lowered HLO text should contain the fully qualified
  op name ``foo/jvp(jit(f))/bar/baz/mul``.
  """

  @jax.jit
  def f(x):
    with jax.named_scope('bar'), jax.named_scope('baz'):
      return jnp.square(x)

  # Kept as a lambda so the traced/lowered function name is unchanged.
  g = jax.named_scope('foo')(lambda x, t: jax.jvp(f, (x, ), (t, )))

  outer_eqn = jax.make_jaxpr(g)(1., 1.).jaxpr.eqns[0]
  self.assertEqual(str(outer_eqn.source_info.name_stack), 'foo')

  # NOTE(review): newer JAX versions represent `jax.jit` calls as `pjit`
  # equations whose sub-jaxpr is stored under params['jaxpr'], and may name
  # the squaring op `square` rather than `mul`; confirm against the pinned
  # JAX version.
  inner_eqn = outer_eqn.params['call_jaxpr'].eqns[0]
  self.assertEqual(str(inner_eqn.source_info.name_stack), 'bar/baz')

  hlo_text = _get_hlo(g)(1., 1.)
  self.assertIn('foo/jvp(jit(f))/bar/baz/mul', hlo_text)
def f(x):
  """Adds one twice, recording each addition under a different name scope.

  The first ``+ 1`` is traced under ``foo`` and the second under
  ``bar/baz``.
  """
  with jax.named_scope('foo'):
    shifted = x + 1
  # NOTE(review): the collapsed original does not show whether the `bar`
  # scope sits inside or beside `foo`; it is reconstructed here as a sibling
  # scope. Confirm against the caller's name-stack assertions.
  with jax.named_scope('bar'), jax.named_scope('baz'):
    return shifted + 1
def f(x):
  """Returns ``x + 1``, traced under the ``bar`` name scope."""
  with jax.named_scope('bar'):
    incremented = x + 1
  return incremented
def f(x):
  """Returns ``jnp.sin(x)``, traced under the ``bar`` name scope."""
  with jax.named_scope('bar'):
    sine = jnp.sin(x)
  return sine
def f(x):
  """Returns ``2 * jnp.sin(x)``, traced under the ``foo`` name scope."""
  with jax.named_scope('foo'):
    sine = jnp.sin(x)
    # Keep the constant as the left operand, as in the original expression.
    doubled = 2 * sine
  return doubled
def f(x):
  """Returns ``jnp.square(x)``, traced under the nested ``bar/baz`` scopes."""
  with jax.named_scope('bar'), jax.named_scope('baz'):
    squared = jnp.square(x)
  return squared
def f(x):
  """Returns ``x + 1``, traced under the ``foo`` name scope."""
  with jax.named_scope('foo'):
    incremented = x + 1
  return incremented