Example #1
0
    def test_jvp_should_transform_stacks(self):
        """Check how ``jax.jvp`` rewrites the name stack of traced equations.

        A function opening ``bar`` then ``baz`` scopes is differentiated with
        ``jax.jvp`` inside an outer ``foo`` scope.  The first equation of the
        traced jaxpr is expected to carry the name stack ``foo/jvp(bar)/baz``.
        """
        def f(x):
            # Two nested scopes, written as a single combined ``with``.
            with jax.named_scope('bar'), jax.named_scope('baz'):
                return jnp.square(x)

        @jax.named_scope('foo')
        def scoped_jvp(primal, tangent):
            return jax.jvp(f, (primal,), (tangent,))

        traced = jax.make_jaxpr(scoped_jvp)(1., 1.)
        first_eqn = traced.jaxpr.eqns[0]
        stack_text = str(first_eqn.source_info.name_stack)
        self.assertEqual(stack_text, 'foo/jvp(bar)/baz')
Example #2
0
    def test_jvp_should_apply_to_call_jaxpr(self):
        """Check name stacks when ``jax.jvp`` is applied to a jitted function.

        The outer call equation is expected to record only ``foo``.  The
        equation inside its ``call_jaxpr`` is expected to keep the inner
        ``bar/baz`` stack.  The lowered HLO text is expected to contain the
        fully qualified ``foo/jvp(jit(f))/bar/baz/mul`` op name.
        """
        # The name ``f`` is significant: it shows up as ``jit(f)`` in the HLO.
        @jax.jit
        def f(x):
            with jax.named_scope('bar'), jax.named_scope('baz'):
                return jnp.square(x)

        @jax.named_scope('foo')
        def scoped_jvp(primal, tangent):
            return jax.jvp(f, (primal,), (tangent,))

        traced = jax.make_jaxpr(scoped_jvp)(1., 1.)
        outer_eqn = traced.jaxpr.eqns[0]
        self.assertEqual(str(outer_eqn.source_info.name_stack), 'foo')

        inner_eqn = outer_eqn.params['call_jaxpr'].eqns[0]
        self.assertEqual(str(inner_eqn.source_info.name_stack), 'bar/baz')

        lowered_text = _get_hlo(scoped_jvp)(1., 1.)
        self.assertIn('foo/jvp(jit(f))/bar/baz/mul', lowered_text)
Example #3
0
 def f(x):
     """Return ``(x + 1) + 1``, with each increment in its own name scope.

     The first increment is traced under ``foo``; the second under the
     nested ``bar/baz`` scope.
     """
     with jax.named_scope('foo'):
         shifted = x + 1
     with jax.named_scope('bar'), jax.named_scope('baz'):
         return shifted + 1
Example #4
0
 def f(x):
     """Return ``x + 1``, traced under the ``bar`` name scope."""
     with jax.named_scope('bar'):
         incremented = x + 1
     return incremented
Example #5
0
 def f(x):
     """Return ``jnp.sin(x)``, traced under the ``bar`` name scope."""
     with jax.named_scope('bar'):
         sine = jnp.sin(x)
     return sine
Example #6
0
 def f(x):
     """Return ``2 * jnp.sin(x)``, traced under the ``foo`` name scope."""
     with jax.named_scope('foo'):
         sine = jnp.sin(x)
         doubled = 2 * sine
     return doubled
Example #7
0
 def f(x):
     """Return ``jnp.square(x)``, traced under the nested ``bar/baz`` scope."""
     with jax.named_scope('bar'), jax.named_scope('baz'):
         squared = jnp.square(x)
     return squared
Example #8
0
 def f(x):
     """Return ``x + 1``, traced under the ``foo`` name scope."""
     with jax.named_scope('foo'):
         bumped = x + 1
     return bumped