def test_simple_debug_print_works_in_eager_mode(self):
  """A debug_print called outside of jit still emits formatted output."""
  def print_value(value):
    debug_print('x: {}', value)

  with capture_stdout() as captured:
    print_value(2)
    # Wait for pending effects to finish before captured stdout is read.
    jax.effects_barrier()
  self.assertEqual(captured(), "x: 2\n")
 def test_can_print_inside_switch(self, ordered):
   """Only the branch selected by ``lax.switch`` should print.

   Args:
     ordered: forwarded to ``debug_print``; presumably supplied by a
       parameterization decorator outside this view — confirm.
   """
   def f(x):
     def b1(x):
       debug_print("b1: {}", x, ordered=ordered)
       return x
     def b2(x):
       debug_print("b2: {}", x, ordered=ordered)
       return x
     def b3(x):
       debug_print("b3: {}", x, ordered=ordered)
       return x
     return lax.switch(x, (b1, b2, b3), x)

   # (branch index, raw multiline text passed to _format_multiline) pairs.
   cases = (
       (0, """
     b1: 0
     """),
       (1, """
     b2: 1
     """),
       (2, """
     b3: 2
     """),
   )
   for index, raw_expected in cases:
     with capture_stdout() as output:
       f(index)
       # Every case waits for pending debug effects before stdout is read,
       # so no case can observe output that has not been flushed yet.
       jax.effects_barrier()
     self.assertEqual(output(), _format_multiline(raw_expected))
 def test_debug_print_works_with_named_format_strings(self):
   """debug_print fills a named ``{x}`` placeholder from keyword arguments."""
   def emit(value):
     debug_print('x: {x}', x=value)

   with capture_stdout() as captured:
     emit(2)
     jax.effects_barrier()
   self.assertEqual(captured(), "x: 2\n")
  def test_can_print_inside_while_loop_cond(self, ordered):
    """A debug_print in a while_loop's cond fires on every cond evaluation.

    Args:
      ordered: forwarded to ``debug_print``.
    """
    def f(start):
      def keep_going(value):
        debug_print("x: {x}", x=value, ordered=ordered)
        return value < 10

      def step(value):
        return value + 1

      return lax.while_loop(keep_going, step, start)

    # Starting at 5, cond is evaluated for 5 through 10; the check at 10 fails.
    with capture_stdout() as captured:
      f(5)
      jax.effects_barrier()
    self.assertEqual(captured(), _format_multiline("""
      x: 5
      x: 6
      x: 7
      x: 8
      x: 9
      x: 10
      """))

    with capture_stdout() as captured:
      f(10)
      jax.effects_barrier()
    # Should run the cond once
    self.assertEqual(captured(), _format_multiline("""
      x: 10
      """))
# Example #5
# 0
    def test_can_print_inside_cond(self, ordered):
        """Only the branch that ``lax.cond`` takes should print.

        Args:
            ordered: forwarded to ``debug_print``.
        """

        def f(x):
            def on_true(value):
                debug_print("true: {}", value, ordered=ordered)
                return value

            def on_false(value):
                debug_print("false: {}", value, ordered=ordered)
                return value

            return lax.cond(x < 5, on_true, on_false, x)

        # 5 < 5 is False, so only the false branch prints.
        with capture_stdout() as output:
            f(5)
            jax.effects_barrier()
        printed = output()
        self.assertEqual(printed,
                         _format_multiline("""
      false: 5
      """))

        # 4 < 5 is True, so only the true branch prints.
        with capture_stdout() as output:
            f(4)
            jax.effects_barrier()
        printed = output()
        self.assertEqual(printed,
                         _format_multiline("""
      true: 4
      """))
 def test_debug_print_batching_with_diff_axes(self):
   """vmap with different in_axes per argument prints one line per element."""
   def show_pair(x, y):
     debug_print('hello: {} {}', x, y)

   # Map the first argument over axis 0 and the second over axis 1.
   batched_show = jax.vmap(show_pair, in_axes=(0, 1))
   with capture_stdout() as captured:
     first = jnp.arange(2)
     second = jnp.arange(2)[None]  # shape (1, 2): mapped along its axis 1
     batched_show(first, second)
     jax.effects_barrier()
   self.assertEqual(captured(), "hello: 0 [0]\nhello: 1 [1]\n")
 def test_debug_print_batching(self):
   """vmap over debug_print prints once per batch element, in index order."""
   def show(x):
     debug_print('hello: {}', x)

   batched_show = jax.vmap(show)
   with capture_stdout() as captured:
     batched_show(jnp.arange(2))
     jax.effects_barrier()
   self.assertEqual(captured(), "hello: 0\nhello: 1\n")
 def test_can_stage_out_debug_print(self):
   """debug_print still fires when the function is staged out with jax.jit."""
   def show(x):
     debug_print('x: {x}', x=x)

   jitted_show = jax.jit(show)
   with capture_stdout() as captured:
     jitted_show(2)
     jax.effects_barrier()
   self.assertEqual(captured(), "x: 2\n")
 def test_multiple_debug_prints_should_print_multiple_values(self):
   """Two debug_print calls in one function yield two lines, in call order."""
   def show_both(value):
     debug_print('x: {x}', x=value)
     debug_print('y: {y}', y=value + 1)

   with capture_stdout() as captured:
     show_both(2)
     jax.effects_barrier()
   self.assertEqual(captured(), "x: 2\ny: 3\n")
 def test_can_stage_out_ordered_print_with_pytree(self):
   """An ordered, jitted debug_print of a dict pytree prints the dict's str()."""
   def show_struct(x):
     debug_print('x: {}', {'foo': x}, ordered=True)

   jitted_show = jax.jit(show_struct)
   with capture_stdout() as captured:
     jitted_show(np.array(2, np.int32))
     jax.effects_barrier()
   printed = captured()
   expected_struct = str({'foo': np.array(2, np.int32)})
   self.assertEqual(printed, f"x: {expected_struct}\n")
 def test_debug_print_transpose_rule(self):
   """Transposing a function that prints must not execute the print."""
   def identity_with_print(x):
     debug_print('should never be called: {}', x)
     return x

   with capture_stdout() as captured:
     transposed = jax.linear_transpose(identity_with_print, 1.)
     transposed(1.)
     jax.effects_barrier()
   # `debug_print` should be dropped by `partial_eval` because of no
   # output data-dependence.
   self.assertEqual(captured(), "")
  def test_unordered_print_with_xmap(self):
    """Every unordered print from an xmap over a CPU-device mesh is emitted.

    NOTE(review): ``_assertLinesEqual`` is defined outside this view; it
    presumably compares lines without regard to order — confirm.
    """
    def show(x):
      debug_print("{}", x, ordered=False)

    mapped_show = maps.xmap(show, in_axes=['a'], out_axes=None,
                            backend='cpu', axis_resources={'a': 'dev'})
    cpu_devices = np.array(jax.devices(backend='cpu'))
    with maps.Mesh(cpu_devices, ['dev']):
      with capture_stdout() as captured:
        mapped_show(jnp.arange(40))
        jax.effects_barrier()
      expected = "".join(f"{i}\n" for i in range(40))
      self._assertLinesEqual(captured(), expected)
 def test_can_print_inside_scan(self, ordered):
   """A debug_print in a scan body reports the carry and element per step.

   Args:
     ordered: forwarded to ``debug_print``.
   """
   def step(carry, elem):
     debug_print("carry: {carry}, x: {x}", carry=carry, x=elem, ordered=ordered)
     return carry + 1, elem + 1

   def f(xs):
     return lax.scan(step, 2, xs)

   with capture_stdout() as captured:
     f(jnp.arange(2))
     jax.effects_barrier()
   self.assertEqual(
       captured(),
       _format_multiline("""
     carry: 2, x: 0
     carry: 3, x: 1
     """))
 def test_can_print_inside_for_loop(self, ordered):
   """A debug_print in a fori_loop body prints once per iteration.

   Args:
     ordered: forwarded to ``debug_print``.
   """
   def step(_, value):
     debug_print("x: {x}", x=value, ordered=ordered)
     return value + 1

   # Five iterations starting from 2: prints 2, 3, 4, 5, 6.
   with capture_stdout() as captured:
     lax.fori_loop(0, 5, step, 2)
     jax.effects_barrier()
   self.assertEqual(captured(), _format_multiline("""
     x: 2
     x: 3
     x: 4
     x: 5
     x: 6
     """))
# Example #15
# 0
  def test_unordered_print_works_in_pmap(self):
    """Unordered prints inside pmap appear once per mapped device.

    Needs at least two local devices: the second pmap below maps over a
    fixed size-2 input, and pmap only spans the local devices.
    """
    num_local_devices = jax.local_device_count()
    if num_local_devices < 2:
      raise unittest.SkipTest("Test requires >= 2 devices.")

    @jax.pmap
    def f(x):
      debug_print("hello: {}", x, ordered=False)
    with capture_stdout() as output:
      f(jnp.arange(num_local_devices))
      jax.effects_barrier()
    # One line per local device; deriving the expectation from the device
    # count keeps the test valid on machines with more than two devices.
    expected = "".join(f"hello: {i}\n" for i in range(num_local_devices))
    self._assertLinesEqual(output(), expected)

    @jax.pmap
    def f2(x):
      debug_print('hello: {}', x)
      debug_print('hello: {}', x + 2)
    with capture_stdout() as output:
      f2(jnp.arange(2))
      jax.effects_barrier()
    self._assertLinesEqual(output(), "hello: 0\nhello: 1\nhello: 2\nhello: 3\n")
 def tested_debug_print_with_nested_vmap(self):
   """Nested vmaps print elements in an order fixed by their in_axes.

   NOTE(review): the name starts with ``tested_`` rather than ``test_``; it
   still matches unittest's default ``test`` method-name prefix.
   """
   def show(x):
     debug_print('hello: {}', x)

   # Input:
   # [[0, 1],
   #  [2, 3],
   #  [4, 5]]
   grid = jnp.arange(6).reshape((3, 2))

   with capture_stdout() as captured:
     # Should print over 0-axis then 1-axis
     jax.vmap(jax.vmap(show))(grid)
     jax.effects_barrier()
   self.assertEqual(
       captured(),
       "hello: 0\nhello: 2\nhello: 4\nhello: 1\nhello: 3\nhello: 5\n")

   with capture_stdout() as captured:
     # Should print over 1-axis then 0-axis
     inner = jax.vmap(show, in_axes=0)
     jax.vmap(inner, in_axes=1)(grid)
     jax.effects_barrier()
   self.assertEqual(
       captured(),
       "hello: 0\nhello: 1\nhello: 2\nhello: 3\nhello: 4\nhello: 5\n")
# Example #17
# 0
    def test_can_use_multiple_breakpoints(self):
        """Two ordered breakpoints in one jitted call each stop once, in order.

        Scripted stdin answers each stop with ``p y`` then ``c``; the expected
        transcript shows ``y`` as 3. at the first stop and 6. at the second.
        """
        # Scripted debugger input: print y and continue, once per breakpoint.
        stdin, stdout = make_fake_stdin_stdout(["p y", "c", "p y", "c"])

        def f(x):
            # Keep the name `y`: "p y" presumably resolves it against this
            # frame's locals (2. + 1. -> 3.).
            y = x + 1.
            debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True)
            return y

        @jax.jit
        def g(x):
            # Second stop: y here is f(x) * 2. -> 6.
            y = f(x) * 2.
            debugger.breakpoint(stdin=stdin, stdout=stdout, ordered=True)
            return jnp.exp(y)

        expected = _format_multiline(r"""
    Entering jaxdb:
    (jaxdb) array(3., dtype=float32)
    (jaxdb) Entering jaxdb:
    (jaxdb) array(6., dtype=float32)
    (jaxdb) """)
        g(jnp.array(2., jnp.float32))
        # Wait for the ordered breakpoint effects before inspecting stdout.
        jax.effects_barrier()
        self.assertEqual(stdout.getvalue(), expected)