def test_simple_debug_print_works_in_eager_mode(self):
    """A positional-format `debug_print` in an undecorated function prints once."""

    def print_value(value):
        debug_print('x: {}', value)

    with capture_stdout() as captured:
        print_value(2)
        # Wait for outstanding side effects before reading the captured text.
        jax.effects_barrier()

    expected = "x: 2\n"
    self.assertEqual(captured(), expected)
def test_can_print_inside_switch(self, ordered):
    """Only the `lax.switch` branch selected by the index prints.

    Args:
      ordered: forwarded to every `debug_print` call as its `ordered` flag.
    """

    def f(x):
        def b1(x):
            debug_print("b1: {}", x, ordered=ordered)
            return x

        def b2(x):
            debug_print("b2: {}", x, ordered=ordered)
            return x

        def b3(x):
            debug_print("b3: {}", x, ordered=ordered)
            return x

        return lax.switch(x, (b1, b2, b3), x)

    with capture_stdout() as output:
        f(0)
        # Wait for the print effect before reading stdout, as the other two
        # cases do; without the barrier the assertion may run before the
        # callback has written its line.
        jax.effects_barrier()
    self.assertEqual(output(), _format_multiline("""
      b1: 0
      """))

    with capture_stdout() as output:
        f(1)
        jax.effects_barrier()
    self.assertEqual(output(), _format_multiline("""
      b2: 1
      """))

    with capture_stdout() as output:
        f(2)
        jax.effects_barrier()
    self.assertEqual(output(), _format_multiline("""
      b3: 2
      """))
def test_debug_print_works_with_named_format_strings(self):
    """A keyword argument fills the matching named field of the format string."""

    def print_named(value):
        debug_print('x: {x}', x=value)

    with capture_stdout() as captured:
        print_named(2)
        jax.effects_barrier()

    expected = "x: 2\n"
    self.assertEqual(captured(), expected)
def test_can_print_inside_while_loop_cond(self, ordered):
    """`debug_print` inside a `lax.while_loop` predicate prints per evaluation.

    Args:
      ordered: forwarded to `debug_print` as its `ordered` flag.
    """

    def count_to_ten(start):
        def keep_going(value):
            debug_print("x: {x}", x=value, ordered=ordered)
            return value < 10

        def increment(value):
            return value + 1

        return lax.while_loop(keep_going, increment, start)

    # Starting at 5 the predicate is evaluated for 5..10 inclusive.
    with capture_stdout() as captured:
        count_to_ten(5)
        jax.effects_barrier()
    self.assertEqual(captured(), _format_multiline("""
      x: 5
      x: 6
      x: 7
      x: 8
      x: 9
      x: 10
      """))

    with capture_stdout() as captured:
        count_to_ten(10)
        jax.effects_barrier()
    # Should run the cond once
    self.assertEqual(captured(), _format_multiline("""
      x: 10
      """))
def test_can_print_inside_cond(self, ordered):
    """Only the `lax.cond` branch chosen by the predicate `x < 5` prints.

    Args:
      ordered: forwarded to every `debug_print` call as its `ordered` flag.
    """

    def branch_on_five(value):
        def when_true(operand):
            debug_print("true: {}", operand, ordered=ordered)
            return operand

        def when_false(operand):
            debug_print("false: {}", operand, ordered=ordered)
            return operand

        return lax.cond(value < 5, when_true, when_false, value)

    # 5 is not < 5, so the false branch is taken.
    with capture_stdout() as captured:
        branch_on_five(5)
        jax.effects_barrier()
    self.assertEqual(captured(), _format_multiline("""
      false: 5
      """))

    with capture_stdout() as captured:
        branch_on_five(4)
        jax.effects_barrier()
    self.assertEqual(captured(), _format_multiline("""
      true: 4
      """))
def test_debug_print_batching_with_diff_axes(self):
    """`debug_print` under `jax.vmap` with a different batch axis per argument."""

    def print_pair(x, y):
        debug_print('hello: {} {}', x, y)

    # x is mapped along axis 0, y along axis 1.
    batched = jax.vmap(print_pair, in_axes=(0, 1))

    with capture_stdout() as captured:
        batched(jnp.arange(2), jnp.arange(2)[None])
        jax.effects_barrier()

    expected = "hello: 0 [0]\nhello: 1 [1]\n"
    self.assertEqual(captured(), expected)
def test_debug_print_batching(self):
    """`debug_print` under `jax.vmap` prints one line per batch element."""

    def print_element(element):
        debug_print('hello: {}', element)

    batched = jax.vmap(print_element)

    with capture_stdout() as captured:
        batched(jnp.arange(2))
        jax.effects_barrier()

    expected = "hello: 0\nhello: 1\n"
    self.assertEqual(captured(), expected)
def test_can_stage_out_debug_print(self):
    """`debug_print` inside a `jax.jit`-compiled function still prints."""

    def print_named(value):
        debug_print('x: {x}', x=value)

    jitted = jax.jit(print_named)

    with capture_stdout() as captured:
        jitted(2)
        jax.effects_barrier()

    expected = "x: 2\n"
    self.assertEqual(captured(), expected)
def test_multiple_debug_prints_should_print_multiple_values(self):
    """Two `debug_print` calls in one function produce two lines, in call order."""

    def print_twice(value):
        debug_print('x: {x}', x=value)
        debug_print('y: {y}', y=value + 1)

    with capture_stdout() as captured:
        print_twice(2)
        jax.effects_barrier()

    expected = "x: 2\ny: 3\n"
    self.assertEqual(captured(), expected)
def test_can_stage_out_ordered_print_with_pytree(self):
    """An ordered `debug_print` under `jax.jit` formats a dict argument."""

    @jax.jit
    def print_struct(value):
        wrapped = dict(foo=value)
        debug_print('x: {}', wrapped, ordered=True)

    with capture_stdout() as captured:
        print_struct(np.array(2, np.int32))
        jax.effects_barrier()

    # The expected text is the `str()` of an equivalent dict built with NumPy.
    expected_struct = dict(foo=np.array(2, np.int32))
    expected = f"x: {str(expected_struct)}\n"
    self.assertEqual(captured(), expected)
def test_debug_print_transpose_rule(self):
    """Transposing a function containing `debug_print` prints nothing."""

    def identity_with_print(value):
        debug_print('should never be called: {}', value)
        return value

    with capture_stdout() as captured:
        transposed = jax.linear_transpose(identity_with_print, 1.)
        transposed(1.)
        jax.effects_barrier()

    # `debug_print` should be dropped by `partial_eval` because of no
    # output data-dependence.
    self.assertEqual(captured(), "")
def test_unordered_print_with_xmap(self):
    """An unordered `debug_print` under `xmap` on a CPU mesh prints every element."""

    def print_element(element):
        debug_print("{}", element, ordered=False)

    mapped = maps.xmap(
        print_element,
        in_axes=['a'],
        out_axes=None,
        backend='cpu',
        axis_resources={'a': 'dev'},
    )
    cpu_mesh = maps.Mesh(np.array(jax.devices(backend='cpu')), ['dev'])

    with cpu_mesh:
        with capture_stdout() as captured:
            mapped(jnp.arange(40))
            jax.effects_barrier()
        # One line per input element; compared with the class's line-based
        # helper rather than `assertEqual`.
        expected = "".join(f"{i}\n" for i in range(40))
        self._assertLinesEqual(captured(), expected)
def test_can_print_inside_scan(self, ordered):
    """`debug_print` inside a `lax.scan` body prints once per scanned element.

    Args:
      ordered: forwarded to `debug_print` as its `ordered` flag.
    """

    def scan_and_print(xs):
        def step(carry, x):
            debug_print("carry: {carry}, x: {x}", carry=carry, x=x,
                        ordered=ordered)
            return carry + 1, x + 1

        # The carry starts at 2.
        return lax.scan(step, 2, xs)

    with capture_stdout() as captured:
        scan_and_print(jnp.arange(2))
        jax.effects_barrier()

    expected = _format_multiline("""
      carry: 2, x: 0
      carry: 3, x: 1
      """)
    self.assertEqual(captured(), expected)
def test_can_print_inside_for_loop(self, ordered):
    """`debug_print` inside a `lax.fori_loop` body prints once per iteration.

    Args:
      ordered: forwarded to `debug_print` as its `ordered` flag.
    """

    def loop_and_print(start):
        def step(i, value):
            debug_print("x: {x}", x=value, ordered=ordered)
            return value + 1

        # Five iterations: indices 0..4.
        return lax.fori_loop(0, 5, step, start)

    with capture_stdout() as captured:
        loop_and_print(2)
        jax.effects_barrier()

    expected = _format_multiline("""
      x: 2
      x: 3
      x: 4
      x: 5
      x: 6
      """)
    self.assertEqual(captured(), expected)
def test_unordered_print_works_in_pmap(self):
    """Unordered `debug_print` under `jax.pmap` prints one line per mapped element.

    Raises:
      unittest.SkipTest: when fewer than two devices are available.
    """
    if jax.device_count() < 2:
        raise unittest.SkipTest("Test requires >= 2 devices.")

    @jax.pmap
    def f(x):
        debug_print("hello: {}", x, ordered=False)

    num_local_devices = jax.local_device_count()
    with capture_stdout() as output:
        f(jnp.arange(num_local_devices))
        jax.effects_barrier()
    # The input has one element per local device, so the expected text must be
    # sized the same way; a literal for two devices fails on larger hosts.
    expected = "".join(f"hello: {i}\n" for i in range(num_local_devices))
    self._assertLinesEqual(output(), expected)

    @jax.pmap
    def f2(x):
        debug_print('hello: {}', x)
        debug_print('hello: {}', x + 2)

    # This case always maps over exactly two elements, so four lines are
    # expected regardless of how many devices exist.
    with capture_stdout() as output:
        f2(jnp.arange(2))
        jax.effects_barrier()
    self._assertLinesEqual(output(), "hello: 0\nhello: 1\nhello: 2\nhello: 3\n")
def tested_debug_print_with_nested_vmap(self):
    """`debug_print` under two nested `jax.vmap`s prints in a fixed element order."""
    # NOTE(review): the `tested_` prefix differs from the `test_` prefix used by
    # the neighbouring tests; unittest's default `test` method prefix still
    # matches it. Confirm whether a rename to `test_...` is wanted.

    def print_element(element):
        debug_print('hello: {}', element)

    # Call with
    # [[0, 1],
    #  [2, 3],
    #  [4, 5]]
    grid = jnp.arange(6).reshape((3, 2))

    with capture_stdout() as captured:
        # Should print over 0-axis then 1-axis
        jax.vmap(jax.vmap(print_element))(grid)
        jax.effects_barrier()
    column_major = (
        "hello: 0\nhello: 2\nhello: 4\nhello: 1\nhello: 3\nhello: 5\n")
    self.assertEqual(captured(), column_major)

    with capture_stdout() as captured:
        # Should print over 1-axis then 0-axis
        jax.vmap(jax.vmap(print_element, in_axes=0), in_axes=1)(grid)
        jax.effects_barrier()
    row_major = (
        "hello: 0\nhello: 1\nhello: 2\nhello: 3\nhello: 4\nhello: 5\n")
    self.assertEqual(captured(), row_major)
def test_can_use_multiple_breakpoints(self):
    """Two ordered breakpoints in one jitted call each read commands from the fake stdin."""
    # Scripted session: print `y` then continue, once per breakpoint.
    stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])

    # NOTE(review): the line layout of this literal (including the trailing
    # "(jaxdb) " with no final newline) was reconstructed; confirm it against
    # the debugger's actual transcript.
    expected = _format_multiline(r"""
    Entering jaxdb:
    (jaxdb) array(3., dtype=float32)
    (jaxdb) Entering jaxdb:
    (jaxdb) array(6., dtype=float32)
    (jaxdb) """)

    # The local name `y` must stay as-is in both functions: the scripted
    # `p y` commands look it up in the stopped frame.
    def f(x):
        y = x + 1.
        debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True)
        return y

    @jax.jit
    def g(x):
        y = f(x) * 2.
        debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True)
        return jnp.exp(y)

    g(jnp.array(2., jnp.float32))
    jax.effects_barrier()
    self.assertEqual(stdout.getvalue(), expected)