示例#1
0
def supported_dtypes(dtypes):
    """Filter ``dtypes`` down to those reported by ``jtu.supported_dtypes()``.

    Args:
        dtypes: iterable of dtype objects to filter.

    Returns:
        A list holding the members of ``dtypes`` that are supported, in their
        original order.
    """
    # Query the supported collection once rather than once per candidate.
    supported = jtu.supported_dtypes()
    return [t for t in dtypes if t in supported]
示例#2
0
def supported_dtypes():
    """Return the dtypes from ``jtu.supported_dtypes()`` ordered by dtype name."""
    def _dtype_name(dtype):
        # Sort key: the canonical numpy name of the dtype.
        return np.dtype(dtype).name

    return sorted(jtu.supported_dtypes(), key=_dtype_name)
示例#3
0
def supported_dtypes(dtypes):
    """Filter ``dtypes`` by backend support and by the x64 flag.

    A dtype is kept when it is reported by ``jtu.supported_dtypes()`` and,
    unless ``FLAGS.jax_enable_x64`` is set, its item size is not 8 bytes.

    Args:
        dtypes: iterable of dtype objects to filter.

    Returns:
        A list with the retained dtypes, in their original order.
    """
    # Query the supported collection once rather than once per candidate.
    supported = jtu.supported_dtypes()
    return [
        t for t in dtypes
        if t in supported and (
            FLAGS.jax_enable_x64 or np.dtype(t).itemsize != 8)
    ]
示例#4
0
                            complex_dtypes, float_dtypes, inexact_dtypes,
                            num_float_bits)

from jax.config import config
# Have the JAX config parse its flags through absl, then keep a module-level
# handle on the config's FLAGS object for use below.
config.parse_flags_with_absl()
FLAGS = config.FLAGS


# Record describing one gradient test case; build instances through
# ``grad_test_spec`` so that ``name`` and ``tol`` get their defaults.
GradTestSpec = collections.namedtuple(
    "GradTestSpec",
    ("op", "nargs", "order", "rng_factory", "dtypes", "name", "tol"))


def grad_test_spec(op, nargs, order, rng_factory, dtypes, name=None, tol=None):
  """Build a GradTestSpec, falling back to ``op.__name__`` for the name.

  ``op.__name__`` is only consulted when ``name`` is falsy, so callables
  without a ``__name__`` are fine as long as a name is passed explicitly.
  """
  if not name:
    name = op.__name__
  return GradTestSpec(op=op, nargs=nargs, order=order,
                      rng_factory=rng_factory, dtypes=dtypes, name=name,
                      tol=tol)

# Float and complex dtypes used for gradient tests, each restricted to what
# jtu.supported_dtypes() reports.
grad_float_dtypes, grad_complex_dtypes = (
    list(jtu.supported_dtypes().intersection(candidates))
    for candidates in ({onp.float32, onp.float64},
                       {onp.complex64, onp.complex128}))
# Floats first, then complex.
grad_inexact_dtypes = grad_float_dtypes + grad_complex_dtypes

LAX_GRAD_OPS = [
    grad_test_spec(lax.neg, nargs=1, order=2, rng_factory=jtu.rand_default,
                   dtypes=grad_inexact_dtypes),
    grad_test_spec(lax.floor, nargs=1, order=2,
                   rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4),
                   dtypes=grad_float_dtypes),
    grad_test_spec(lax.ceil, nargs=1, order=2,
                   rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4),
                   dtypes=grad_float_dtypes),
    grad_test_spec(lax.round, nargs=1, order=2,
                   rng_factory=partial(jtu.rand_uniform, low=0.1, high=0.4),
示例#5
0
class HostCallbackTest(jtu.JaxTestCase):
    def setUp(self):
        """Reset the shared test stream and remember the current XLA_FLAGS."""
        testing_stream.reset()
        # Tag the stream with the name of the test that is about to run.
        testing_stream.testMethodName = self._testMethodName
        # Saved so tearDown can undo changes made through the helper_set_*
        # methods; an unset variable is recorded as "".
        # NOTE(review): the base class setUp is not invoked here - confirm
        # that is intentional.
        self.old_flags = os.getenv("XLA_FLAGS", "")

    def tearDown(self) -> None:
        """Restore XLA_FLAGS to the value saved in setUp if a test changed it.

        When the variable was modified, the cached backend is cleared as well
        so that a new backend picks up the restored value.
        """
        # setUp records an unset variable as "", so it has to be read the
        # same way here; comparing None against "" would make every test look
        # as if it had modified the flags.
        if os.getenv("XLA_FLAGS", "") != self.old_flags:
            os.environ["XLA_FLAGS"] = self.old_flags
            xla_bridge.get_backend.cache_clear()

    def helper_set_devices(self, nr_devices):
        """Append a forced host device count to XLA_FLAGS and list the devices.

        Args:
            nr_devices: value for --xla_force_host_platform_device_count.

        Returns:
            The result of api.devices() after the backend cache was cleared.
        """
        current_flags = os.getenv("XLA_FLAGS", "")
        device_flag = f"--xla_force_host_platform_device_count={nr_devices}"
        os.environ["XLA_FLAGS"] = f"{current_flags} {device_flag}"
        # Drop cached backends so that a fresh CPU backend sees the env var.
        xla_bridge.get_backend.cache_clear()
        return api.devices()

    def helper_set_hlo_dump(self):
        """Append --xla_dump_to=/tmp/xla_dump to XLA_FLAGS."""
        current_flags = os.getenv("XLA_FLAGS", "")
        os.environ["XLA_FLAGS"] = current_flags + " --xla_dump_to=/tmp/xla_dump"
        # Drop cached backends so that a fresh CPU backend sees the env var.
        xla_bridge.get_backend.cache_clear()

    def test_eval(self):
        """Check the jaxpr of fun1 and the tap output of running it without jit."""
        # fun1 is defined outside this class. Its jaxpr is expected to contain
        # two id_tap equations ("a * 2" and "y * 3").
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b = mul a 2.00
      c = id_tap[ arg_treedef=*
                  func=_print
                  what=a * 2 ] b
      d = mul c 3.00
      e f = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    what=y * 3 ] d c
      g = pow f 2.00
  in (g,) }""", str(api.make_jaxpr(fun1)(5.)))
        # Building the jaxpr alone must not have printed anything.
        self.assertEqual("", testing_stream.output)

        # Now actually run the function and check result and printed taps.
        with hcb.outfeed_receiver():
            self.assertAllClose((5. * 2.)**2, fun1(5.), check_dtypes=True)
        assertMultiLineStrippedEqual(
            self, """
what: a * 2
10.00
what: y * 3
30.00""", testing_stream.output)
        testing_stream.reset()

    def test_with_tuple_results(self):
        """id_print on a tuple returns a tuple and prints both elements."""
        def func2(x):
            # Tap a pair; the call hands back a pair that is unpacked here.
            x1, y1 = hcb.id_print((x * 2., x * 3.),
                                  output_stream=testing_stream)
            return x1 + y1

        # The pair is expected to show up as a single id_tap equation with a
        # tuple arg_treedef.
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b = mul a 2.00
      c = mul a 3.00
      d e = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
                    func=_print
                    ] b c
      f = add d e
  in (f,) }""", str(api.make_jaxpr(func2)(3.)))
        with hcb.outfeed_receiver():
            self.assertEqual(3. * (2. + 3.), func2(3.))
        assertMultiLineStrippedEqual(self, """
[ 6.00
  9.00 ]""", testing_stream.output)
        testing_stream.reset()

    def test_with_dict_results(self):
        def func2(x):
            res = hcb.id_print(dict(a=x * 2., b=x * 3.),
                               output_stream=testing_stream)
            return res["a"] + res["b"]

        with hcb.outfeed_receiver():
            self.assertEqual(3. * (2. + 3.), func2(3.))
        assertMultiLineStrippedEqual(self, """
{ a=6.00
  b=9.00 }""", testing_stream.output)
        testing_stream.reset()

    def test_with_result(self):
        def func2(x):
            x1 = hcb.id_print((x * 2., x * 3.),
                              result=x * 4.,
                              output_stream=testing_stream)
            return x1

        with hcb.outfeed_receiver():
            self.assertEqual(3. * 4., func2(3.))
        assertMultiLineStrippedEqual(self, """
[ 6.00
  9.00 ]""", testing_stream.output)
        testing_stream.reset()

    def test_eval_tap_exception(self):
        # Simulate a tap error
        def tap_err(*args, **kwargs):
            raise NotImplementedError

        def func(x):
            x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
            x2 = hcb.id_tap(tap_err, x1 + 1, what="err")
            x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
            return x3

        with self.assertRaises(hcb.TapFunctionException):
            with hcb.outfeed_receiver():
                res = func(0)

        # We should have received everything before the error
        assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
        testing_stream.reset()

    def test_jit_simple(self):
        jit_fun1 = api.jit(lambda x: 3. * hcb.id_print(
            2. * x, what="here", output_stream=testing_stream))

        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(jit_fun1)(5.).GetHloText())
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            res = jit_fun1(5.)

        self.assertAllClose(6. * 5., res, check_dtypes=True)
        assertMultiLineStrippedEqual(self, """
what: here
10.00""", testing_stream.output)
        testing_stream.reset()

    def test_jit_sequence1(self):
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            return hcb.id_print(x1 + 1,
                                where="2",
                                output_stream=testing_stream)

        logging.info("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
        logging.info("%s: %s", self._testMethodName,
                     api.xla_computation(func)(1).GetHloText())

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(2, api.jit(func)(1))
        assertMultiLineStrippedEqual(self, """
where: 1
1
where: 2
2""", testing_stream.output)
        testing_stream.reset()

    def test_jit2(self):
        """A sequence of JIT."""
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)
            return x2

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(2, api.jit(func)(1))
            self.assertEqual(11, api.jit(func)(10))

        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: 1
10
where: 2
11""", testing_stream.output)
        testing_stream.reset()

    def test_jit_nested(self):
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)

            def func_nested(x):
                x2 = hcb.id_print(x + 1,
                                  where="nested",
                                  output_stream=testing_stream)
                return x2

            x3 = api.jit(func_nested)(x1)
            return hcb.id_print(x3 + 1,
                                where="3",
                                output_stream=testing_stream)

        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).GetHloText())
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(3, api.jit(func)(1))
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: nested
2
where: 3
3""", testing_stream.output)
        testing_stream.reset()

    def test_jit_devices(self):
        """Run the same jitted, tapped function once on every local device."""
        devices = api.local_devices()
        logging.info(f"{self._testMethodName}: has devices {devices}")

        def func(x, device_id):
            # device_id is passed as a static argument, see the jit call below.
            dev_label = str(device_id)
            x1 = hcb.id_print(x, dev=dev_label, output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1,
                              dev=dev_label,
                              output_stream=testing_stream)
            return x2

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            for d in devices:
                jitted = api.jit(func, device=d, static_argnums=1)
                self.assertEqual(112, jitted(111, d.id))

        output = testing_stream.output
        logging.info(
            f"{self._testMethodName}: found output {output}")
        # Each device must have contributed one "111" and one "112".
        nr_devices = len(devices)
        for needle in (r"111", r"112"):
            self.assertEqual(nr_devices, len(re.findall(needle, output)))
        testing_stream.reset()

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_with_jit_{with_jit}", with_jit=with_jit)
            for with_jit in [True, False]))
    def test_pytree(self, with_jit=False):
        """id_tap passes pytrees to the tap function and returns `result`."""
        def func(x, what=""):
            """Returns some pytrees depending on x"""
            if what == "pair_1_x":
                return (1, x)
            if what == "pair_x_2x":
                return (x, 2 * x)
            if what == "dict":
                return dict(a=2 * x, b=3 * x)
            assert False

        tap_count = 0

        def tap_func(a, what=""):
            # Count the invocation and check the tapped pytree for x == 5.
            nonlocal tap_count
            tap_count += 1
            self.assertEqual(func(5, what), a)

        if with_jit:
            transform = api.jit
        else:
            transform = lambda f: f

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            for what in ("pair_1_x", "pair_x_2x", "dict"):
                tapped = transform(lambda x: hcb.id_tap(tap_func,
                                                        func(x, what),
                                                        result=func(x * 2, what),
                                                        what=what))
                self.assertEqual(func(10, what), tapped(5))
        # The count is checked only after the receiver context has exited.
        self.assertEqual(3, tap_count)

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_with_jit_{with_jit}", with_jit=with_jit)
            for with_jit in [True, False]))
    def test_cond(self, with_jit=False):
        """Taps before, inside and after a lax.cond, with and without jit."""
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            # Five-argument lax.cond: predicate, then (operand, function) for
            # the true branch followed by (operand, function) for the false
            # branch. The false branch prints -1 and passes its operand on.
            x4 = lax.cond(
                x % 2 == 0, x2 + 1, lambda x: hcb.id_print(
                    x, where="cond_t", output_stream=testing_stream), x2 + 1,
                lambda x: hcb.id_print(
                    -1, where="cond_f", result=x, output_stream=testing_stream)
            )
            x5 = hcb.id_print(x4 + 1,
                              where="end",
                              output_stream=testing_stream)
            return x5

        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).GetHloText())
        transform = api.jit if with_jit else lambda f: f
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(4, transform(func)(1))
        # For x == 1 the predicate is false, so only "cond_f" is expected.
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: cond_f
-1
where: end
4""", testing_stream.output)
        testing_stream.reset()

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_with_jit_{with_jit}", with_jit=with_jit)
            for with_jit in [True, False]))
    def test_while_cond(self, with_jit=False):
        """Taps inside a lax.cond nested in a lax.while_loop body."""
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            def body(x):
                # Loop body: tap on entry, tap in the branch taken by the
                # cond, tap on exit.
                x3 = hcb.id_print(x,
                                  where="w_b_1",
                                  output_stream=testing_stream)
                x4 = lax.cond(
                    x % 2 == 0, x3 + 1, lambda x: hcb.id_print(
                        x, where="w_b_t", output_stream=testing_stream),
                    x3 + 1,
                    lambda x: hcb.id_print(-1,
                                           where="w_b_f",
                                           result=x,
                                           output_stream=testing_stream))
                return hcb.id_print(x4,
                                    where="w_b_2",
                                    output_stream=testing_stream)

            # Iterate while the carried value is <= 3, starting from x2.
            x10 = lax.while_loop(lambda x: x <= 3, body, x2)
            res = hcb.id_print(x10, where="end", output_stream=testing_stream)
            return res

        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).GetHloText())
        transform = api.jit if with_jit else lambda f: f
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(4, transform(func)(1))
        # Two loop iterations are expected: the first takes the true branch
        # (value 2), the second the false branch (value 3).
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: w_b_1
2
where: w_b_t
3
where: w_b_2
3
where: w_b_1
3
where: w_b_f
-1
where: w_b_2
4
where: end
4""", testing_stream.output)
        testing_stream.reset()

    def test_jit_while_pred_printing(self):
        """While with printing in the conditional."""
        # Unconditionally skipped; everything below is unreachable until the
        # skip is removed.
        raise SkipTest("Not yet implemented")

        #TODO: implement printing inside conditional
        def func(x):
            # NOTE(review): unlike the other taps in this test, this one does
            # not pass output_stream=testing_stream - confirm when re-enabling.
            x1 = hcb.id_print(x, where="1")

            def body(x):
                x3 = hcb.id_print(x, where="w_1", output_stream=testing_stream)
                return hcb.id_print(x3 + 1,
                                    where="w_2",
                                    output_stream=testing_stream)

            # The tap sits inside the loop predicate.
            x10 = lax.while_loop(
                lambda x: hcb.id_print(
                    x < 10, where="w_p", output_stream=testing_stream), body,
                x1)
            res = hcb.id_print(x10, where="10", output_stream=testing_stream)
            return res

        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).GetHloText())

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(10, api.jit(func)(1))
        # The expected output has not been filled in yet.
        assertMultiLineStrippedEqual(self, """
""", testing_stream.output)
        testing_stream.reset()

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(testcase_name=f"_with_jit_{with_jit}", with_jit=with_jit)
            for with_jit in [True, False]))
    def test_scan_cond(self, with_jit=False):
        """Taps inside a lax.cond nested in a lax.scan body."""
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            def body(c, x):
                # Scan body: the carry c is passed through unchanged; the
                # per-step output is the scanned element plus one.
                x3 = hcb.id_print(x, where="s_1", output_stream=testing_stream)
                x4 = lax.cond(
                    x % 2 == 0, x3 + 1, lambda x: hcb.id_print(
                        x, where="s_t", output_stream=testing_stream), x3 + 1,
                    lambda x: hcb.id_print(-1,
                                           where="s_f",
                                           result=x,
                                           output_stream=testing_stream))
                return (c,
                        hcb.id_print(x4,
                                     where="s_2",
                                     output_stream=testing_stream))

            # Scan over [0, 1, 2]; only the stacked outputs are used.
            _, x10 = lax.scan(body, x2, np.arange(3))
            res = hcb.id_print(x10, where="10", output_stream=testing_stream)
            return res

        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).GetHloText())

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            if with_jit:
                func = api.jit(func)
            res = func(1)
            self.assertAllClose(np.array([1, 2, 3]), res, check_dtypes=True)
        # Elements 0 and 2 take the true branch, element 1 the false branch.
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: s_1
0
where: s_t
1
where: s_2
1
where: s_1
1
where: s_f
-1
where: s_2
2
where: s_1
2
where: s_t
3
where: s_2
3
where: 10
[1 2 3]""", testing_stream.output)
        testing_stream.reset()

    @parameterized.named_parameters(
        jtu.cases_from_list(
            dict(
                testcase_name=f"_shape_{shape}_dtype_{dtype}_nr_args={nr_args}",
                shape=shape,
                dtype=dtype,
                nr_args=nr_args) for nr_args in [1, 2]
            for shape in [(), (2, ), (2, 3), (2, 3, 4)]
            for dtype in jtu.supported_dtypes()))
    def test_jit_types(self, nr_args=2, dtype=np.int16, shape=(2, )):
        """Run id_print under jit over combinations of shape, dtype and arity.

        Only checks that the call completes; the returned value is not
        compared against the input (see the commented-out assertion).
        """
        if dtype in (np.complex64, np.complex128, np.bool_):
            raise SkipTest(f"id_print jit not implemented for {dtype}.")
        if jtu.device_under_test() == "tpu":
            if dtype in (np.int16, ):
                raise SkipTest(f"transfering {dtype} not supported on TPU")
        # One array holding 0..prod(shape)-1 in the requested shape.
        args = [np.arange(np.prod(shape), dtype=dtype).reshape(shape)]
        if nr_args > 1:
            # Repeat the same array nr_args times.
            args = args * nr_args
        jit_fun1 = api.jit(lambda xs: hcb.id_print(
            xs,
            a_new_test="************",
            testcase_name=f"shape_{shape}_dtype_{dtype}_nr_args={nr_args}"))
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            res = jit_fun1(args)
        # self.assertAllClose(args, res, check_dtypes=True)

    def test_jit_large(self):
        """Run a jitted id_print on a 10000-element int32 array."""
        nr_elements = 10000
        arg = np.arange(nr_elements, dtype=np.int32).reshape((10, 10, 5, -1))
        jit_print = api.jit(hcb.id_print)
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            jit_print(arg)

    def test_jit_several_together(self):
        """Tap a tuple of three differently shaped arrays in one jitted call."""
        arg = np.arange(50, dtype=np.int32).reshape((10, 5))
        jit_print = api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            ones = np.ones(100, dtype=np.int32)
            jit_print(arg, ones)

    def test_jit_interleaving(self):
        """Several jit calls without data dependencies; they may interfere."""
        tap_invocations = 0
        nr_arrays = 5

        def tap_func(arg, **kwargs):
            nonlocal tap_invocations
            assert len(arg) == nr_arrays
            tap_invocations += 1

        def func(x, count):
            # Chain `count` taps; each tap receives nr_arrays arrays and the
            # last of them feeds the next iteration.
            for i in range(count):
                tapped = hcb.id_tap(
                    tap_func, [x + offset for offset in range(nr_arrays)],
                    i=i)
                x = tapped[-1]
            return x

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            x = np.array(1, dtype=onp.int32)
            res = 0
            for _ in range(10):
                # No dependencies between the jit invocations
                res += api.jit(lambda x: func(x, 10))(x)

        hlo_text = api.xla_computation(lambda x: func(x, 5))(1).GetHloText()
        logging.warning("%s: %s", self._testMethodName, hlo_text)
        # 10 jit invocations with 10 taps each.
        self.assertEqual(100, tap_invocations)

    def test_jit_tap_exception(self):
        # Simulate a tap error
        def tap_err(*args, **kwargs):
            raise NotImplementedError

        def func(x):
            x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
            x2 = hcb.id_tap(tap_err, x1 + 1, what="err")
            x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
            return x3

        with self.assertRaises(hcb.TapFunctionException):
            with hcb.outfeed_receiver(receiver_name=self._testMethodName):
                res = api.jit(func)(0)
        # Even though the receiver thread raised, the main thread should still
        # return 3.
        self.assertEqual(3, res)
        # We should have received all others
        assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
        testing_stream.reset()

    def test_jit_unknown_tap(self):
        # Simulate an unknown tap function
        def func(x):
            x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
            x2 = hcb.id_tap(hcb._unknown_testing_consumer, x1 + 1, what="err")
            x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
            return x3

        with self.assertRaises(hcb.TapFunctionException):
            with hcb.outfeed_receiver(receiver_name=self._testMethodName):
                res = api.jit(func)(0)
        # Even though the receiver thread raised, the main thread should still
        # return 3.
        self.assertEqual(3, res)
        # We should have received all others
        assertMultiLineStrippedEqual(self, """
what: x1
1
what: x3
3""", testing_stream.output)
        testing_stream.reset()

    # On CPU and GPU the device code blocks
    # On GPU it seems that there is a 5 min timeout?
    # On TPU the client does not block, but messes up the rest somehow
    @jtu.skip_on_devices("cpu", "gpu", "tpu")
    def test_jit_receiver_ends_prematurely(self):
        """The consumer loop is ended by a tap in the middle of a computation.

        Skipped on cpu, gpu and tpu; see the comments above the decorator.
        This method used to be defined twice in the class with identical
        bodies, so the later copy silently replaced this one; only a single
        definition is kept.
        """
        # Simulate the receiver ending before the computation is finished
        def func(x):
            x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
            x2 = hcb.id_tap(hcb._end_consumer,
                            result=x1 + 1)  # Will end the consumer loop
            x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
            return x3

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            res = api.jit(func)(0)

        # Reaching this point is unexpected: the jit call above seems to block.
        self.fail("It seems that the previous jit blocks above")

    def test_jit_error_no_consumer(self):
        """Running a jitted tap without an active receiver raises ValueError."""
        # Check for errors if starting jit without a consumer active
        with self.assertRaisesRegex(ValueError,
                                    "outfeed_receiver is not started"):
            api.jit(lambda x: hcb.id_print(x))(0)

    def test_jit_nested_cond_no_print(self):
        """A nested conditional, without any prints"""
        # Unconditionally skipped; everything below is unreachable until the
        # skip is removed.
        raise SkipTest("skip this")

        @api.jit
        def cfun(x):
            # Outer cond on x < 2; its false branch holds a second cond.
            return lax.cond(
                lax.lt(x, 2), x, lambda x: x, x,
                lambda x: lax.cond(x < 5, 3, lambda x: x, 4, lambda y: y))

        print(self._testMethodName, api.xla_computation(cfun)(1).GetHloText())
        cfun(1)

    def test_while(self):
        """Executing while, even without JIT uses compiled code"""
        y = np.ones(5)  # captured const

        def func(x):
            return lax.while_loop(
                lambda c: c[1] < 5, lambda c:
                (y, hcb.id_print(c[1], output_stream=testing_stream) + 1),
                (x, 1))

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            func(y)
        assertMultiLineStrippedEqual(self, """
1
2
3
4""", testing_stream.output)
        testing_stream.reset()

    def test_while_error_no_receiver(self):
        """Executing while needs the receiver"""
        y = np.ones(5)  # captured const

        def func(x):
            def keep_going(carry):
                return carry[1] < 5

            def step(carry):
                counter = hcb.id_print(carry[1], output_stream=testing_stream)
                return (y, counter + 1)

            return lax.while_loop(keep_going, step, (x, 1))

        # No outfeed_receiver is active here, so the call must be rejected.
        expected_message = ".*outfeed_receiver.*not started"
        with self.assertRaisesRegex(ValueError, expected_message):
            func(y).block_until_ready()

    def test_jvp(self):
        """Check jaxpr, results and tap output of api.jvp applied to fun1."""
        # fun1 is defined outside this class.
        jvp_fun1 = lambda x, xt: api.jvp(fun1, (x, ), (xt, ))
        # The expected jaxpr has the two primal taps plus two more taps
        # carrying transforms=('jvp',).
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a b.
  let c = mul a 2.00
      d = id_tap[ arg_treedef=*
                  func=_print
                  nr_untapped=0
                  what=a * 2 ] c
      e = mul d 3.00
      f g = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    what=y * 3 ] e d
      h = pow g 2.00
      i = mul b 2.00
      j k = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    transforms=('jvp',)
                    what=a * 2 ] i d
      l = mul j 3.00
      m n o = id_tap[ arg_treedef=*
                      func=_print
                      nr_untapped=2
                      transforms=('jvp',)
                      what=y * 3 ] l j f
      p = pow g 1.00
      q = mul 2.00 p
      r = mul n q
  in (h, r) }""",
            str(api.make_jaxpr(jvp_fun1)(np.float32(5.), np.float32(0.1))))
        with hcb.outfeed_receiver():
            res_primals, res_tangents = jvp_fun1(np.float32(5.),
                                                 np.float32(0.1))
        self.assertAllClose(100., res_primals, check_dtypes=False)
        self.assertAllClose(4., res_tangents, check_dtypes=False)
        # Each primal tap is expected to be followed by its tangent tap.
        assertMultiLineStrippedEqual(
            self, """
what: a * 2
10.00
transforms: ('jvp',) what: a * 2
0.20
what: y * 3
30.00
transforms: ('jvp',) what: y * 3
0.60""", testing_stream.output)
        testing_stream.reset()

    def test_grad_primal_unused(self):
        """Gradient of a function whose tap output is not needed backwards."""
        # The output of id_print is not needed for backwards pass
        def func(x):
            return 2. * hcb.id_print(
                x * 3., what="x * 3", output_stream=testing_stream)

        grad_func = api.grad(func)
        # The expected jaxpr has no equations, only the constant 6.00.
        with hcb.outfeed_receiver():
            assertMultiLineStrippedEqual(
                self, """
{ lambda  ; a.
  let
  in (6.00,) }""", str(api.make_jaxpr(grad_func)(5.)))

        # Just making the Jaxpr invokes the id_print once
        assertMultiLineStrippedEqual(
            self, """
transforms: ('jvp', 'transpose') what: x * 3
2.00""", testing_stream.output)
        testing_stream.reset()

        # Now evaluate the gradient for real.
        with hcb.outfeed_receiver():
            res_grad = grad_func(np.float32(5.))

        self.assertAllClose(6., res_grad, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
what: x * 3
15.00
transforms: ('jvp', 'transpose') what: x * 3
2.00""", testing_stream.output)
        testing_stream.reset()

    def test_grad_simple(self):
        """Gradient of a function with two chained taps."""
        def func(x):
            y = hcb.id_print(x * 2.,
                             what="x * 2",
                             output_stream=testing_stream)
            return x * hcb.id_print(
                y * 3., what="y * 3", output_stream=testing_stream)

        grad_func = api.grad(func)
        # The expected jaxpr lists the two transposed taps ("y * 3" first,
        # then "x * 2") ahead of the two primal taps.
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b = mul 1.00 a
      c d = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    transforms=('jvp', 'transpose')
                    what=y * 3 ] b 0.00
      e = mul c 3.00
      f g = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    transforms=('jvp', 'transpose')
                    what=x * 2 ] e 0.00
      h = mul f 2.00
      i = mul a 2.00
      j = id_tap[ arg_treedef=*
                  func=_print
                  nr_untapped=0
                  what=x * 2 ] i
      k = mul j 3.00
      l = id_tap[ arg_treedef=*
                  func=_print
                  nr_untapped=0
                  what=y * 3 ] k
      m = mul 1.00 l
      n = add_any h m
  in (n,) }""", str(api.make_jaxpr(grad_func)(5.)))

        with hcb.outfeed_receiver():
            res_grad = grad_func(np.float32(5.))
        self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
        # At run time the primal taps are expected first, then the
        # transposed taps in reverse program order.
        assertMultiLineStrippedEqual(
            self, """
what: x * 2
10.00
what: y * 3
30.00
transforms: ('jvp', 'transpose') what: y * 3
5.00
transforms: ('jvp', 'transpose') what: x * 2
15.00""", testing_stream.output)
        testing_stream.reset()

    def test_grad_double(self):
        """Second derivative (grad of grad) of a function with one tap."""
        def func(x):
            y = hcb.id_print(x * 2.,
                             what="x * 2",
                             output_stream=testing_stream)
            return x * (y * 3.)

        grad_func = api.grad(api.grad(func))
        with hcb.outfeed_receiver():
            # The expected jaxpr has no equations, only the constant 12.00.
            assertMultiLineStrippedEqual(
                self, """
{ lambda  ; a.
  let 
  in (12.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
            # Just making the Jaxpr invokes the id_print twice
            assertMultiLineStrippedEqual(
                self, """
transforms: ('jvp', 'transpose') what: x * 2
3.00
transforms: ('jvp', 'transpose', 'jvp', 'transpose') what: x * 2
2.00""", testing_stream.output)
            testing_stream.reset()
            # Now evaluate the second derivative for real.
            res_grad = grad_func(np.float32(5.))

        self.assertAllClose(12., res_grad, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
what: x * 2
10.00
transforms: ('jvp', 'transpose') what: x * 2
15.00
transforms: ('jvp', 'transpose', 'jvp', 'transpose') what: x * 2
2.00
transforms: ('jvp', 'transpose') what: x * 2
3.00""", testing_stream.output)
        testing_stream.reset()

    def test_vmap(self):
        """Check the jaxpr and tap output of api.vmap applied to fun1."""
        # fun1 is defined outside this class.
        vmap_fun1 = api.vmap(fun1)
        vargs = np.array([np.float32(4.), np.float32(5.)])
        # Both taps are expected to carry batch_dims and
        # transforms=('batch',).
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b = mul a 2.00
      c = id_tap[ arg_treedef=*
                  batch_dims=(0,)
                  func=_print
                  transforms=('batch',)
                  what=a * 2 ] b
      d = mul c 3.00
      e f = id_tap[ arg_treedef=*
                    batch_dims=(0, 0)
                    func=_print
                    nr_untapped=1
                    transforms=('batch',)
                    what=y * 3 ] d c
      g = pow f 2.00
  in (g,) }""", str(api.make_jaxpr(vmap_fun1)(vargs)))
        with hcb.outfeed_receiver():
            # Only the printed output is checked; res_vmap is not compared.
            res_vmap = vmap_fun1(vargs)
        assertMultiLineStrippedEqual(
            self, """
batch_dims: (0,) transforms: ('batch',) what: a * 2
[ 8.00 10.00]
batch_dims: (0, 0) transforms: ('batch',) what: y * 3
[24.00 30.00]""", testing_stream.output)
        testing_stream.reset()

    def test_vmap_not_batched(self):
        """vmap over a tap whose tuple argument has one unmapped element."""
        x = 3.

        def func(y):
            # x is not mapped, y is mapped
            _, y = hcb.id_print((x, y), output_stream=testing_stream)
            return x + y

        vmap_func = api.vmap(func)
        vargs = np.array([np.float32(4.), np.float32(5.)])
        # The unmapped element is expected to show up as None in batch_dims.
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b c = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
                    batch_dims=(None, 0)
                    func=_print
                    transforms=('batch',) ] 3.00 a
      d = add c 3.00
  in (d,) }""", str(api.make_jaxpr(vmap_func)(vargs)))
        with hcb.outfeed_receiver():
            # Only the printed output is checked; res_vmap is not compared.
            res_vmap = vmap_func(vargs)
        assertMultiLineStrippedEqual(
            self, """
batch_dims: (None, 0) transforms: ('batch',)
[ 3.00
  [4.00 5.00] ]
   """, testing_stream.output)
        testing_stream.reset()

    def test_pmap(self):
        """pmap of fun1 over all local devices matches fun1_equiv per element."""
        vargs = 2. + np.arange(api.local_device_count(), dtype=np.float32)
        pmap_fun1 = api.pmap(fun1, axis_name="i")

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            res = pmap_fun1(vargs)

        # Reference values: fun1_equiv applied to each device's input.
        per_device = [
            fun1_equiv(2. + a) for a in range(api.local_device_count())
        ]
        expected_res = np.stack(per_device)
        self.assertAllClose(expected_res, res, check_dtypes=False)

    def test_pmap_error_no_receiver(self):
        """Running a pmapped tap without an active receiver raises ValueError."""
        vargs = 2. + np.arange(api.local_device_count(), dtype=np.float32)
        pmapped_print = api.pmap(lambda x: hcb.id_print(x))
        # No outfeed_receiver is active here, so the call must be rejected.
        with self.assertRaisesRegex(ValueError,
                                    "outfeed_receiver is not started"):
            pmapped_print(vargs)

    def test_mask(self):
        # TODO(necula)
        raise SkipTest("masking has regressed")

        @partial(api.mask, in_shapes=['n'], out_shape='')
        def padded_sum(x):
            return np.sum(
                hcb.id_print(x, what="x", output_stream=testing_stream))

        args = [np.arange(4)], dict(n=onp.int64(2))
        assertMultiLineStrippedEqual(
            self, """
{ lambda c f ; a b.
  let d = lt c b
      e = id_tap[ func=_print
                  logical_shapes=[(Traced<ShapedArray(int32[]):JaxprTrace(level=0/0)>,)]
                  transforms=('mask',)
                  what=x ] a
      g = select d e f
      h = reduce_sum[ axes=(0,) ] g
  in (h,) }""", str(api.make_jaxpr(padded_sum)(*args)))

        res = padded_sum(*args)
        self.assertMultiLineStrippedEqual(
            """
logical_shapes: [(2,)] transforms: ('mask',) what: x
[0 1 2 3]
   """, testing_stream.output)
        testing_stream.reset()