Exemplo n.º 1
0
 def func(x):
     """Multiply ``x`` by 1, 2 and 3 in turn, printing every partial product.

     Each intermediate value goes through ``hcb.id_print`` labelled
     "x times i" with ``output_stream=testing_stream``; the final value
     returned by the last ``id_print`` call is returned.
     """
     result = x
     for factor in range(1, 4):
         result = hcb.id_print(result * factor,
                               what="x times i",
                               output_stream=testing_stream)
     return result
Exemplo n.º 2
0
 def func(x):
     """Run a ``lax.while_loop`` whose body prints the loop counter.

     The initial carry is ``(x, 1)``; iteration continues while the counter
     ``carry[1]`` is below 5. The body replaces the first carry element with
     the closure variable ``y`` and returns the printed counter plus one.
     """
     def cond_fun(carry):
         return carry[1] < 5

     def body_fun(carry):
         # NOTE(review): ``y`` comes from the enclosing scope; if its type
         # differs from ``x`` the carry types will not match -- presumably
         # the enclosing test expects that. Confirm against the caller.
         return (y, hcb.id_print(carry[1]) + 1)

     return lax.while_loop(cond_fun, body_fun, (x, 1))
Exemplo n.º 3
0
 def func(x):
     """Print ``x + 1``, tap the next value with ``tap_err``, then print again.

     ``tap_err`` is a tap function from the enclosing scope. The value
     returned by the final ``id_print`` (labelled "x3") is returned.
     """
     first = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
     tapped = hcb.id_tap(tap_err, first + 1, what="err")
     return hcb.id_print(tapped + 1, what="x3", output_stream=testing_stream)
Exemplo n.º 4
0
 def test_jit_error_no_consumer(self):
     """Calling a jitted ``id_print`` with no consumer running must fail.

     The failure is expected to be a ``ValueError`` whose message matches
     "outfeed_receiver is not started".
     """
     def print_arg(v):
         return hcb.id_print(v)

     with self.assertRaisesRegex(ValueError,
                                 "outfeed_receiver is not started"):
         api.jit(print_arg)(0)
Exemplo n.º 5
0
 def func(x):
     """Print ``x`` (where="1"), then print its successor (where="2").

     Returns the value produced by the second ``id_print`` call.
     """
     return hcb.id_print(
         hcb.id_print(x, where="1", output_stream=testing_stream) + 1,
         where="2",
         output_stream=testing_stream)
Exemplo n.º 6
0
 def body(x):
     """Print ``x`` (where="w_1") and then its successor (where="w_2").

     Returns the value produced by the second ``id_print`` call.
     """
     printed = hcb.id_print(x, where="w_1", output_stream=testing_stream)
     successor = printed + 1
     return hcb.id_print(successor, where="w_2", output_stream=testing_stream)
Exemplo n.º 7
0
 def func(x):
     """``lax.while_loop`` whose body prints the counter to ``testing_stream``.

     The initial carry is ``(x, 1)``; the loop continues while ``carry[1] < 5``.
     The body swaps the first carry element for the closure variable ``y``
     and returns the printed counter plus one.
     """
     def cond_fun(carry):
         return carry[1] < 5

     def body_fun(carry):
         counter = hcb.id_print(carry[1], output_stream=testing_stream)
         # NOTE(review): ``y`` is taken from the enclosing scope; a type
         # mismatch with ``x`` is presumably what the test exercises.
         return (y, counter + 1)

     return lax.while_loop(cond_fun, body_fun, (x, 1))
Exemplo n.º 8
0
 def func2(x):
     """Print the dict ``{"a": x * 2., "b": x * 3.}`` and add its two entries."""
     printed = hcb.id_print(dict(a=x * 2., b=x * 3.),
                            output_stream=testing_stream)
     total = printed["a"] + printed["b"]
     return total
Exemplo n.º 9
0
 def func(x):
     """``lax.while_loop`` that prints in both its condition and its body.

     The condition prints the closure value ``ct_cond`` with
     ``result=carry[1]`` and compares the returned value with 5. The body
     replaces the first carry element with the closure value ``ct_body``
     and returns the printed counter plus one. Initial carry: ``(x, 1)``.
     """
     def cond_fun(carry):
         counter = hcb.id_print(ct_cond, result=carry[1])
         return counter < 5

     def body_fun(carry):
         return (ct_body, hcb.id_print(carry[1]) + 1)

     return lax.while_loop(cond_fun, body_fun, (x, 1))
Exemplo n.º 10
0
 def func(x):
     """Print the constant 42 via ``id_print`` with ``result=x``.

     Returns whatever ``id_print`` returns for that call.
     """
     printed_const = 42
     return hcb.id_print(printed_const, result=x, output_stream=testing_stream)
Exemplo n.º 11
0
 def func(x):
     """``lax.while_loop`` with a float counter and a data-dependent bound.

     The loop runs while ``carry[1] < jnp.sum(carry[0] + ct_cond)``. The body
     replaces the first carry element with the closure value ``ct_body`` and
     returns the printed counter plus ``1.``. Initial carry:
     ``(x, np.float32(1.))``.
     """
     def cond_fun(carry):
         return carry[1] < jnp.sum(carry[0] + ct_cond)

     def body_fun(carry):
         return (ct_body, hcb.id_print(carry[1]) + 1.)

     init = (x, np.float32(1.))
     return lax.while_loop(cond_fun, body_fun, init)
Exemplo n.º 12
0
 def sum(x, y):
     """Print ``x + y`` to ``testing_stream`` and return what ``id_print`` returns.

     NOTE(review): this name shadows the builtin ``sum``; it is kept because
     callers refer to it by this name.
     """
     total = x + y
     return hcb.id_print(total, output_stream=testing_stream)
Exemplo n.º 13
0
 def test_jit_several_together(self):
     """Print a tuple of three arrays from a single jitted call.

     The arrays have shapes (10, 5), (100,) and (10, 5).
     """
     def print_all(a, b):
         return hcb.id_print((a, b, a * 2.))

     matrix = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))
     ones = jnp.ones(100, dtype=jnp.int32)
     api.jit(print_all)(matrix, ones)
Exemplo n.º 14
0
 def func(x):
     """``lax.scan`` over ``x`` whose step prints the carry.

     The initial carry is ``(1, 2)``. Each step returns the printed carry as
     the new carry and the closure variable ``y`` as the per-step output.
     """
     def scan_body(carry, _elem):
         return (hcb.id_print(carry), y)

     return lax.scan(scan_body, (1, 2), x)
Exemplo n.º 15
0
 def func(x):
     """Print ``x * 3.`` (labelled "x * 3") and return twice the printed value."""
     tripled = hcb.id_print(x * 3., what="x * 3", output_stream=testing_stream)
     return 2. * tripled
Exemplo n.º 16
0
 def func2(x):
     """Print the pair ``(x * 2., x * 3.)`` and return the sum of its parts."""
     printed = hcb.id_print((x * 2., x * 3.), output_stream=testing_stream)
     doubled, tripled = printed
     return doubled + tripled
Exemplo n.º 17
0
 def func(x):
     """Print ``x * 2.`` (labelled "x * 2") and return ``x * (printed * 3.)``."""
     doubled = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
     # Keep the original grouping so the float multiplication order is unchanged.
     return x * (doubled * 3.)
Exemplo n.º 18
0
 def func2(x):
     """Print the pair ``(x * 2., x * 3.)`` with ``result=x * 4.``.

     Returns whatever ``id_print`` returns for that call.
     """
     pair = (x * 2., x * 3.)
     return hcb.id_print(pair, result=x * 4., output_stream=testing_stream)
Exemplo n.º 19
0
 def func(y):
     """Print ``(x, y)`` and return ``x`` plus the printed ``y``.

     ``x`` is a closure variable (not mapped); ``y`` is the mapped argument.
     """
     _unused_x, printed_y = hcb.id_print((x, y), output_stream=testing_stream)
     return x + printed_y
Exemplo n.º 20
0
 def func_nested(x):
     """Print ``x + 1`` with where="nested" and return what ``id_print`` returns."""
     return hcb.id_print(x + 1, where="nested", output_stream=testing_stream)
Exemplo n.º 21
0
 def test_pmap_error_no_receiver(self):
     """``pmap`` of a function using ``id_print`` fails without a receiver.

     Expects a ``ValueError`` whose message matches
     "outfeed_receiver is not started".
     """
     # Check for errors if starting pmap without an outfeed receiver active.
     # Build one float32 element per local device for pmap to map over.
     vargs = 2. + jnp.arange(api.local_device_count(), dtype=jnp.float32)
     with self.assertRaisesRegex(ValueError,
                                 "outfeed_receiver is not started"):
         api.pmap(lambda x: hcb.id_print(x))(vargs)
Exemplo n.º 22
0
 def test_jit_several_together(self):
     """Print a tuple of three arrays from one jitted call under a receiver.

     Runs inside an ``hcb.outfeed_receiver`` context named after the test
     method. The printed arrays have shapes (10, 5), (100,) and (10, 5).
     """
     def print_all(a, b):
         return hcb.id_print((a, b, a * 2.))

     matrix = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))
     with hcb.outfeed_receiver(receiver_name=self._testMethodName):
         api.jit(print_all)(matrix, jnp.ones(100, dtype=jnp.int32))
Exemplo n.º 23
0
 def padded_sum(x):
     """Print ``x`` (labelled "x") and return ``jnp.sum`` of the printed value."""
     printed = hcb.id_print(x, what="x", output_stream=testing_stream)
     return jnp.sum(printed)
Exemplo n.º 24
0
 def func(x):
     """Print, tap with ``hcb._unknown_testing_consumer``, then print again.

     Presumably the enclosing test expects the unknown consumer to be
     reported as an error -- confirm against the caller.
     """
     first = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
     tapped = hcb.id_tap(hcb._unknown_testing_consumer, first + 1, what="err")
     return hcb.id_print(tapped + 1, what="x3", output_stream=testing_stream)
Exemplo n.º 25
0
 def func(x, z):
     """Legacy five-argument ``lax.cond`` whose false branch prints its operand.

     Form used: ``cond(pred, true_operand, true_fun, false_operand, false_fun)``
     with pred ``z > 0``, true operand ``(1, 2)`` and false operand ``z``.
     ``x`` is not used.

     NOTE(review): the branches return ``jnp.zeros(5)`` vs the closure
     variable ``y`` in the second slot; they may not have matching types,
     presumably intentionally -- confirm against the enclosing test.
     """
     def true_fun(a):
         return (a[0], jnp.zeros(5))

     def false_fun(a):
         return (hcb.id_print(a), y)

     return lax.cond(z > 0, (1, 2), true_fun, z, false_fun)
Exemplo n.º 26
0
 def func(x):
     """Print ``x + 1``, tap with ``hcb._end_consumer``, then print again.

     The tap passes only ``result=`` and is intended to end the consumer
     loop, per the original comment.
     """
     first = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
     # Will end the consumer loop.
     ended = hcb.id_tap(hcb._end_consumer, result=first + 1)
     return hcb.id_print(ended + 1, what="x3", output_stream=testing_stream)
Exemplo n.º 27
0
 def func(x):
     """Multiply ``x`` successively by 0, 1, 2, 3, 4, printing each product.

     Prints use ``id_print`` with the label "x times i" and its default
     output stream.

     NOTE(review): the first factor is 0, so for finite ``x`` every printed
     value is 0 -- confirm this is intended (the ``range(1, 4)`` variant
     avoids it).
     """
     value = x
     for factor in range(5):
         value = hcb.id_print(value * factor, what="x times i")
     return value