예제 #1
0
    def test_grad_primal_unused(self):
        # The output of id_print is not needed for backwards pass
        def func(x):
            return 2. * hcb.id_print(
                x * 3., what="x * 3", output_stream=testing_stream)

        grad_func = api.grad(func)
        with hcb.outfeed_receiver():
            assertMultiLineStrippedEqual(
                self, """
{ lambda  ; a.
  let
  in (6.00,) }""", str(api.make_jaxpr(grad_func)(5.)))

        # Just making the Jaxpr invokes the id_print once
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00""", testing_stream.output)
        testing_stream.reset()

        with hcb.outfeed_receiver():
            res_grad = grad_func(jnp.float32(5.))

        self.assertAllClose(6., res_grad, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
what: x * 3
15.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
2.00""", testing_stream.output)
        testing_stream.reset()
예제 #2
0
    def test_double_vmap(self):
        # A 2D tensor with x[i, j] = i + j using 2 vmap
        def sum(x, y):
            return hcb.id_print(x + y, output_stream=testing_stream)

        def sum_rows(xv, y):
            return api.vmap(sum, in_axes=(0, None))(xv, y)

        def sum_all(xv, yv):
            return api.vmap(sum_rows, in_axes=(None, 0))(xv, yv)

        xv = jnp.arange(5, dtype=np.int32)
        yv = jnp.arange(3, dtype=np.int32)
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a b.
  let c = broadcast_in_dim[ broadcast_dimensions=(1,)
                            shape=(3, 5) ] a
      d = reshape[ dimensions=None
                   new_sizes=(3, 1) ] b
      e = add c d
      f = id_tap[ arg_treedef=*
                  func=_print
                  transforms=(('batch', (0,)), ('batch', (0,))) ] e
  in (f,) }""", str(api.make_jaxpr(sum_all)(xv, yv)))
        with hcb.outfeed_receiver():
            _ = sum_all(xv, yv)
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)}, {'name': 'batch', 'batch_dims': (0,)})
[[0 1 2 3 4]
 [1 2 3 4 5]
 [2 3 4 5 6]]""", testing_stream.output)
        testing_stream.reset()
예제 #3
0
    def test_vmap_not_batched(self):
        x = 3.

        def func(y):
            # x is not mapped, y is mapped
            _, y = hcb.id_print((x, y), output_stream=testing_stream)
            return x + y

        vmap_func = api.vmap(func)
        vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b c = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*])
                    func=_print
                    transforms=(('batch', (None, 0)),) ] 3.00 a
      d = add c 3.00
  in (d,) }""", str(api.make_jaxpr(vmap_func)(vargs)))
        with hcb.outfeed_receiver():
            _ = vmap_func(vargs)
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
[ 3.00
  [4.00 5.00] ]""", testing_stream.output)
        testing_stream.reset()
예제 #4
0
  def test_grad_double(self):
    def func(x):
      y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
      return x * (y * 3.)

    grad_func = api.grad(api.grad(func))
    with hcb.outfeed_receiver():
      _ = api.make_jaxpr(grad_func)(5.)
      # Just making the Jaxpr invokes the id_print twiceonce
      assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
3.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
2.00""", testing_stream.output)
      testing_stream.reset()
      res_grad = grad_func(jnp.float32(5.))

    self.assertAllClose(12., res_grad, check_dtypes=False)
    assertMultiLineStrippedEqual(self, """
what: x * 2
10.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
15.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
2.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
3.00""", testing_stream.output)
    testing_stream.reset()
예제 #5
0
    def test_vmap_while_tap_cond(self):
        """Vmap of while, with a tap in the conditional."""
        def func(x):
            # like max(x, 2)
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = lax.while_loop(
                lambda x: hcb.id_print(
                    x < 2, where="w_c", output_stream=testing_stream),
                lambda x: hcb.id_print(
                    x + 1, where="w_b", output_stream=testing_stream), x1)
            res = hcb.id_print(x2, where="3", output_stream=testing_stream)
            return res

        inputs = np.arange(5, dtype=np.int32)
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertAllClose(np.array([2, 2, 2, 3, 4]),
                                api.jit(api.vmap(func))(inputs),
                                check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1
[0 1 2 3 4]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c
[ True  True False False False]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b
[1 2 3 4 5]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c
[ True False False False False]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b
[2 3 3 4 5]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c
[False False False False False]
transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3
[2 2 2 3 4]""", testing_stream.output)
        testing_stream.reset()
예제 #6
0
    def test_vmap(self):
        vmap_fun1 = api.vmap(fun1)
        vargs = np.array([np.float32(4.), np.float32(5.)])
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b = mul a 2.00
      c = id_tap[ arg_treedef=*
                  batch_dims=(0,)
                  func=_print
                  transforms=('batch',)
                  what=a * 2 ] b
      d = mul c 3.00
      e f = id_tap[ arg_treedef=*
                    batch_dims=(0, 0)
                    func=_print
                    nr_untapped=1
                    transforms=('batch',)
                    what=y * 3 ] d c
      g = pow f 2.00
  in (g,) }""", str(api.make_jaxpr(vmap_fun1)(vargs)))
        with hcb.outfeed_receiver():
            res_vmap = vmap_fun1(vargs)
        assertMultiLineStrippedEqual(
            self, """
batch_dims: (0,) transforms: ('batch',) what: a * 2
[ 8.00 10.00]
batch_dims: (0, 0) transforms: ('batch',) what: y * 3
[24.00 30.00]""", testing_stream.output)
        testing_stream.reset()
예제 #7
0
    def test_grad_simple(self):
        def func(x):
            y = hcb.id_print(x * 2.,
                             what="x * 2",
                             output_stream=testing_stream)
            return x * hcb.id_print(
                y * 3., what="y * 3", output_stream=testing_stream)

        grad_func = api.grad(func)
        #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(grad_func)(5.)))

        with hcb.outfeed_receiver():
            res_grad = grad_func(jnp.float32(5.))
        self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
what: x * 2
10.00
what: y * 3
30.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: y * 3
5.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
15.00""", testing_stream.output)
        testing_stream.reset()
예제 #8
0
    def test_jit_while_pred_tap(self):
        """While with printing in the conditional."""
        def func(x):
            x1 = hcb.id_print(x, where="1")
            x10 = lax.while_loop(
                lambda x: hcb.id_print(
                    x < 3, where="w_p", output_stream=testing_stream),
                lambda x: hcb.id_print(
                    x + 1, where="w_b", output_stream=testing_stream), x1)
            res = hcb.id_print(x10, where="3", output_stream=testing_stream)
            return res

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(3, api.jit(func)(1))
        assertMultiLineStrippedEqual(
            self, """
where: w_p
True
where: w_b
2
where: w_p
True
where: w_b
3
where: w_p
False
where: 3
3""", testing_stream.output)
        testing_stream.reset()
예제 #9
0
    def test_jit_interleaving(self):
        """Several jit calls without data dependencies between them.

        Their taps may interfere; the test checks that every tap reaches
        tap_func (10 jit calls x 10 taps each).
        """
        count = 0  # Count tap invocations
        nr_arrays = 5

        def tap_func(arg, **kwargs):
            nonlocal count
            assert len(arg) == nr_arrays
            count += 1

        # This is the function that we'll run multiple times. The parameter is
        # `nr_taps` (not `count`) so it does not shadow the tap counter.
        def func(x, nr_taps):
            for i in range(nr_taps):
                # Tap a list of nr_arrays values and continue with the last.
                x = hcb.id_tap(tap_func, [x + j for j in range(nr_arrays)],
                               i=i)[-1]
            return x

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            x = jnp.array(1, dtype=np.int32)
            res = 0
            for _ in range(10):
                # No dependencies between the jit invocations
                res += api.jit(lambda x: func(x, 10))(x)
        # Debug aid: log the HLO of a smaller variant.
        logging.warning(
            "%s: %s", self._testMethodName,
            api.xla_computation(lambda x: func(x, 5))(1).as_hlo_text())
        self.assertEqual(100, count)
예제 #10
0
    def test_pytree(self, with_jit=False):
        """id_tap passes pytree arguments to the tap and returns `result`.

        Args:
          with_jit: if True, wrap the tapping function in api.jit.
        """
        def func(x, what=""):
            """Returns some pytrees depending on x"""
            if what == "pair_1_x":
                return (1, x)
            if what == "pair_x_2x":
                return (x, 2 * x)
            if what == "dict":
                return dict(a=2 * x, b=3 * x)
            assert False

        tap_count = 0

        def tap_func(a, what=""):
            # The tap receives func(5, what) as its argument.
            nonlocal tap_count
            tap_count += 1
            self.assertEqual(func(5, what), a)

        if with_jit:
            transform = api.jit
        else:
            def transform(f):
                return f

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            for what in ("pair_1_x", "pair_x_2x", "dict"):
                expected = func(10, what)

                def tapped(x):
                    return hcb.id_tap(tap_func,
                                      func(x, what),
                                      result=func(x * 2, what),
                                      what=what)

                self.assertEqual(expected, transform(tapped)(5))
        # Leaving the receiver context waits for the taps to be processed.
        self.assertEqual(3, tap_count)
예제 #11
0
    def test_jit_while_pred_printing(self):
        """While with printing in the conditional.

        Currently skipped: everything after the SkipTest below is unreachable
        and is kept only as a sketch of the intended test.
        """
        raise SkipTest("Not yet implemented")

        #TODO: implement printing inside conditional
        def func(x):
            # NOTE(review): unlike the other taps, this one has no
            # output_stream=testing_stream, so it would not be checked.
            x1 = hcb.id_print(x, where="1")

            def body(x):
                x3 = hcb.id_print(x, where="w_1", output_stream=testing_stream)
                return hcb.id_print(x3 + 1,
                                    where="w_2",
                                    output_stream=testing_stream)

            x10 = lax.while_loop(
                lambda x: hcb.id_print(
                    x < 10, where="w_p", output_stream=testing_stream), body,
                x1)
            res = hcb.id_print(x10, where="10", output_stream=testing_stream)
            return res

        # Debug aid: log the traced jaxpr and the HLO text.
        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).as_hlo_text())

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(10, api.jit(func)(1))
        # TODO: the expected output has not been filled in yet.
        assertMultiLineStrippedEqual(self, """
""", testing_stream.output)
        testing_stream.reset()
예제 #12
0
    def test_scan_cond(self, with_jit=False):
        """Scan whose body contains a lax.cond, with taps everywhere.

        Args:
          with_jit: if True, run `func` under api.jit.
        """
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            def body(c, x):
                # The carry `c` is returned unchanged; `x` is the current
                # element of jnp.arange(3).
                x3 = hcb.id_print(x, where="s_1", output_stream=testing_stream)
                # 5-argument form: lax.cond(pred, true_operand, true_fun,
                # false_operand, false_fun). The false branch taps -1 but
                # returns its operand through `result=x`.
                x4 = lax.cond(
                    x % 2 == 0, x3 + 1, lambda x: hcb.id_print(
                        x, where="s_t", output_stream=testing_stream), x3 + 1,
                    lambda x: hcb.id_print(-1,
                                           where="s_f",
                                           result=x,
                                           output_stream=testing_stream))
                return (c,
                        hcb.id_print(x4,
                                     where="s_2",
                                     output_stream=testing_stream))

            _, x10 = lax.scan(body, x2, jnp.arange(3))
            res = hcb.id_print(x10, where="10", output_stream=testing_stream)
            return res

        # Debug aid: log the traced jaxpr and the HLO text.
        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).as_hlo_text())

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            if with_jit:
                func = api.jit(func)
            res = func(1)
            self.assertAllClose(jnp.array([1, 2, 3]), res, check_dtypes=True)
        # Per iteration: s_1, then s_t (even x) or s_f (odd x), then s_2.
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: s_1
0
where: s_t
1
where: s_2
1
where: s_1
1
where: s_f
-1
where: s_2
2
where: s_1
2
where: s_t
3
where: s_2
3
where: 10
[1 2 3]""", testing_stream.output)
        testing_stream.reset()
예제 #13
0
    def test_vmap(self):
        vmap_fun1 = api.vmap(fun1)
        vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b = mul a 2.00
      c = id_tap[ arg_treedef=*
                  func=_print
                  transforms=(('batch', (0,)),)
                  what=a * 2 ] b
      d = mul c 3.00
      e f = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    transforms=(('batch', (0, 0)),)
                    what=y * 3 ] d c
      g = integer_pow[ y=2 ] f
  in (g,) }""", str(api.make_jaxpr(vmap_fun1)(vargs)))
        with hcb.outfeed_receiver():
            _ = vmap_fun1(vargs)
        assertMultiLineStrippedEqual(
            self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2
[ 8.00 10.00]
transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3
[24.00 30.00]""", testing_stream.output)
        testing_stream.reset()
예제 #14
0
    def test_jit_nested(self):
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)

            def func_nested(x):
                x2 = hcb.id_print(x + 1,
                                  where="nested",
                                  output_stream=testing_stream)
                return x2

            x3 = api.jit(func_nested)(x1)
            return hcb.id_print(x3 + 1,
                                where="3",
                                output_stream=testing_stream)

        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).as_hlo_text())
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(3, api.jit(func)(1))
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: nested
2
where: 3
3""", testing_stream.output)
        testing_stream.reset()
예제 #15
0
    def test_while(self):
        y = jnp.ones(5)  # captured const

        def func(x):
            return lax.while_loop(lambda c: c[1] < 5, lambda c:
                                  (y, hcb.id_print(c[1]) + 1), (x, 1))

        # TODO: we should not need to start a receiver here!!! I believe this is
        # because of the partial evaluation of while, which calls impl, which
        # uses JIT.
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertRewrite(
                """
{ lambda b ; a e.
  let c d f = while[ body_jaxpr={ lambda  ; c a b f.
                                  let d g = id_tap[ arg_treedef=*
                                                    func=_print
                                                    ] b f
                                      e = add d 1
                                  in (c, e, g) }
                     body_nconsts=1
                     cond_jaxpr={ lambda  ; a b d.
                                  let c = lt b 5
                                  in (c,) }
                     cond_nconsts=0 ] b a 1 e
  in (c, 5, f) }""", func, [y])
예제 #16
0
    def test_while(self):
        ct_body = jnp.ones(5, np.float32)  # captured const for the body
        ct_cond = jnp.ones(5, np.float32)  # captured const for the conditional

        def func(x):
            return lax.while_loop(lambda c: c[1] < jnp.sum(c[0] + ct_cond),
                                  lambda c: (ct_body, hcb.id_print(c[1]) + 1.),
                                  (x, np.float32(1.)))

        # TODO: we should not need to start a receiver here!!! I believe this is
        # because of the partial evaluation of while, which calls impl, which
        # uses JIT.
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertRewrite(
                """
{ lambda b c ; a f.
  let d e g = while[ body_jaxpr={ lambda  ; c a b f.
                                  let d g = id_tap[ arg_treedef=*
                                                    func=_print
                                                    ] b f
                                      e = add d 1.00
                                  in (c, e, g) }
                     body_nconsts=1
                     cond_jaxpr={ lambda  ; c a b g.
                                  let d = add a c
                                      e = reduce_sum[ axes=(0,) ] d
                                      f = lt b e
                                  in (f,) }
                     cond_nconsts=1 ] b c a 1.00 f
  in (d, e, g) }""", func, [ct_body])
예제 #17
0
    def test_jit_devices(self):
        """Running on multiple devices.

        Each local device runs the jitted function once, so the output must
        contain one "111" and one "112" per device.
        """
        devices = api.local_devices()
        # Lazy %-style args, matching the file's other logging calls.
        logging.info("%s: has devices %s", self._testMethodName, devices)

        def func(x, device_id):
            x1 = hcb.id_print(x,
                              dev=str(device_id),
                              output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1,
                              dev=str(device_id),
                              output_stream=testing_stream)
            return x2

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            for d in devices:
                self.assertEqual(
                    112,
                    api.jit(func, device=d, static_argnums=1)(111, d.id))
        logging.info("%s: found output %s", self._testMethodName,
                     testing_stream.output)
        self.assertEqual(len(devices),
                         len(re.findall(r"111", testing_stream.output)))
        self.assertEqual(len(devices),
                         len(re.findall(r"112", testing_stream.output)))
        testing_stream.reset()
예제 #18
0
    def test_eval(self):
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b = mul a 2.00
      c = id_tap[ arg_treedef=*
                  func=_print
                  what=a * 2 ] b
      d = mul c 3.00
      e f = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    what=y * 3 ] d c
      g = integer_pow[ y=2 ] f
  in (g,) }""", str(api.make_jaxpr(fun1)(5.)))
        self.assertEqual("", testing_stream.output)

        with hcb.outfeed_receiver():
            self.assertAllClose((5. * 2.)**2, fun1(5.))
        assertMultiLineStrippedEqual(
            self, """
what: a * 2
10.00
what: y * 3
30.00""", testing_stream.output)
        testing_stream.reset()
예제 #19
0
    def test_jit_nested(self):
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)

            def func_nested(x):
                x2 = hcb.id_print(x + 1,
                                  where="nested",
                                  output_stream=testing_stream)
                return x2

            x3 = api.jit(func_nested)(x1)
            return hcb.id_print(x3 + 1,
                                where="3",
                                output_stream=testing_stream)

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(3, api.jit(func)(1))
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: nested
2
where: 3
3""", testing_stream.output)
        testing_stream.reset()
예제 #20
0
    def test_cond(self, with_jit=False):
        """A conditional"""
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            x4 = lax.cond(
                x % 2 == 0, lambda x: hcb.id_print(
                    x, where="cond_t", output_stream=testing_stream),
                lambda x: hcb.id_print(
                    -1, where="cond_f", result=x, output_stream=testing_stream
                ), x2 + 1)
            x5 = hcb.id_print(x4 + 1,
                              where="end",
                              output_stream=testing_stream)
            return x5

        transform = api.jit if with_jit else lambda f: f
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(4, transform(func)(1))
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: cond_f
-1
where: end
4""", testing_stream.output)
        testing_stream.reset()
예제 #21
0
    def test_jit_constant(self):
        def func(x):
            return hcb.id_print(42, result=x, output_stream=testing_stream)

        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b = xla_call[ backend=None
                    call_jaxpr={ lambda  ; a.
                                 let b c = id_tap[ arg_treedef=*
                                                   func=_print
                                                   nr_untapped=1
                                                   ] 42 a
                                 in (c,) }
                    device=None
                    donated_invars=(False,)
                    name=func ] a
  in (b,) }""", str(api.make_jaxpr(api.jit(func))(5)))
        self.assertEqual("", testing_stream.output)

        with hcb.outfeed_receiver():
            self.assertAllClose(5, api.jit(func)(5))
        assertMultiLineStrippedEqual(self, """
42""", testing_stream.output)
        testing_stream.reset()
예제 #22
0
  def test_pmap(self):
    """pmap(fun1) over all local devices matches fun1_equiv per shard."""
    nr_devices = api.local_device_count()
    vargs = 2. + jnp.arange(nr_devices, dtype=jnp.float32)

    pmap_fun1 = api.pmap(fun1, axis_name="i")
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
      res = pmap_fun1(vargs)
    expected_res = jnp.stack(
        [fun1_equiv(2. + shard) for shard in range(nr_devices)])
    self.assertAllClose(expected_res, res, check_dtypes=False)
예제 #23
0
    def test_outfeed_receiver(self):
        """Test the deprecated outfeed_receiver"""
        with hcb.outfeed_receiver():
            self.assertAllClose((5. * 2.)**2, fun1(5.), check_dtypes=True)
        assertMultiLineStrippedEqual(
            self, """
what: a * 2
10.00
what: y * 3
30.00""", testing_stream.output)
        testing_stream.reset()
예제 #24
0
  def test_jit_constant(self):
    def func(x):
      return hcb.id_print(42, result=x, output_stream=testing_stream)

    #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(api.jit(func))(5)))

    with hcb.outfeed_receiver():
      self.assertAllClose(5, api.jit(func)(5))
    assertMultiLineStrippedEqual(self, """
42""", testing_stream.output)
    testing_stream.reset()
예제 #25
0
  def test_with_dict_results(self):
    def func2(x):
      res = hcb.id_print(dict(a=x * 2., b=x * 3.), output_stream=testing_stream)
      return res["a"] + res["b"]

    with hcb.outfeed_receiver():
      self.assertEqual(3. * (2. + 3.), func2(3.))
    assertMultiLineStrippedEqual(self, """
{ a=6.00
  b=9.00 }""", testing_stream.output)
    testing_stream.reset()
예제 #26
0
  def test_vmap(self):
    vmap_fun1 = api.vmap(fun1)
    vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
    #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_fun1)(vargs)))
    with hcb.outfeed_receiver():
      _ = vmap_fun1(vargs)
    assertMultiLineStrippedEqual(self, """
transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2
[ 8.00 10.00]
transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3
[24.00 30.00]""", testing_stream.output)
    testing_stream.reset()
예제 #27
0
    def test_while_cond(self, with_jit=False):
        """While loop whose body contains a lax.cond, with taps everywhere.

        Args:
          with_jit: if True, run `func` under api.jit.
        """
        def func(x):
            x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
            x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

            def body(x):
                x3 = hcb.id_print(x,
                                  where="w_b_1",
                                  output_stream=testing_stream)
                # 5-argument form: lax.cond(pred, true_operand, true_fun,
                # false_operand, false_fun). The false branch taps -1 but
                # returns its operand through `result=x`.
                x4 = lax.cond(
                    x % 2 == 0, x3 + 1, lambda x: hcb.id_print(
                        x, where="w_b_t", output_stream=testing_stream),
                    x3 + 1,
                    lambda x: hcb.id_print(-1,
                                           where="w_b_f",
                                           result=x,
                                           output_stream=testing_stream))
                return hcb.id_print(x4,
                                    where="w_b_2",
                                    output_stream=testing_stream)

            x10 = lax.while_loop(lambda x: x <= 3, body, x2)
            res = hcb.id_print(x10, where="end", output_stream=testing_stream)
            return res

        # Debug aid: log the traced jaxpr and the HLO text.
        logging.warning("%s: %s", self._testMethodName,
                        api.make_jaxpr(func)(1))
        logging.warning("%s: %s", self._testMethodName,
                        api.xla_computation(func)(1).as_hlo_text())
        transform = api.jit if with_jit else lambda f: f
        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            self.assertEqual(4, transform(func)(1))
        # Two loop iterations: x=2 takes the true branch, x=3 the false one.
        assertMultiLineStrippedEqual(
            self, """
where: 1
1
where: 2
2
where: w_b_1
2
where: w_b_t
3
where: w_b_2
3
where: w_b_1
3
where: w_b_f
-1
where: w_b_2
4
where: end
4""", testing_stream.output)
        testing_stream.reset()
예제 #28
0
    def test_grad_simple(self):
        def func(x):
            y = hcb.id_print(x * 2.,
                             what="x * 2",
                             output_stream=testing_stream)
            return x * hcb.id_print(
                y * 3., what="y * 3", output_stream=testing_stream)

        grad_func = api.grad(func)
        assertMultiLineStrippedEqual(
            self, """
{ lambda  ; a.
  let b = mul 1.00 a
      c d = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    transforms=(('jvp',), ('transpose',))
                    what=y * 3 ] b 0.00
      e = mul c 3.00
      f g = id_tap[ arg_treedef=*
                    func=_print
                    nr_untapped=1
                    transforms=(('jvp',), ('transpose',))
                    what=x * 2 ] e 0.00
      h = mul f 2.00
      i = mul a 2.00
      j = id_tap[ arg_treedef=*
                  func=_print
                  nr_untapped=0
                  what=x * 2 ] i
      k = mul j 3.00
      l = id_tap[ arg_treedef=*
                  func=_print
                  nr_untapped=0
                  what=y * 3 ] k
      m = mul 1.00 l
      n = add_any h m
  in (n,) }""", str(api.make_jaxpr(grad_func)(5.)))

        with hcb.outfeed_receiver():
            res_grad = grad_func(jnp.float32(5.))
        self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
        assertMultiLineStrippedEqual(
            self, """
what: x * 2
10.00
what: y * 3
30.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: y * 3
5.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
15.00""", testing_stream.output)
        testing_stream.reset()
예제 #29
0
  def test_jit_receiver_ends_prematurely(self):
    """A tap that ends the consumer loop in the middle of a jitted function.

    NOTE(review): this test is unfinished. It ends in an unconditional
    failure; per the original note, the jit call appears to block first.
    """
    # Simulate an unknown tap function
    def func(x):
      x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
      x2 = hcb.id_tap(hcb._end_consumer, result=x1 + 1)  # Will end the consumer loop
      x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
      return x3

    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
      _ = api.jit(func)(0)

    # self.fail instead of a bare `assert False`, which `python -O` strips.
    self.fail("Not yet implemented: it seems that the previous jit blocks")
예제 #30
0
    def test_jit_simple(self):
        jit_fun1 = api.jit(lambda x: 3. * hcb.id_print(
            2. * x, what="here", output_stream=testing_stream))

        with hcb.outfeed_receiver(receiver_name=self._testMethodName):
            res = jit_fun1(5.)

        self.assertAllClose(6. * 5., res)
        assertMultiLineStrippedEqual(self, """
what: here
10.00""", testing_stream.output)
        testing_stream.reset()