def test_jit_while_pred_tap(self):
  """While with printing in the conditional.

  All taps, including the one on the initial value, go to testing_stream so
  the whole tap sequence is checked and nothing is printed to stdout.
  """
  def func(x):
    # Route this tap to testing_stream like every other tap in this test.
    x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
    x10 = lax.while_loop(
        lambda x: hcb.id_print(
            x < 3, where="w_p", output_stream=testing_stream),
        lambda x: hcb.id_print(
            x + 1, where="w_b", output_stream=testing_stream),
        x1)
    res = hcb.id_print(x10, where="3", output_stream=testing_stream)
    return res

  self.assertEqual(3, api.jit(func)(1))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(
      self, """
      where: 1
      1
      where: w_p
      True
      where: w_b
      2
      where: w_p
      True
      where: w_b
      3
      where: w_p
      False
      where: 3
      3""", testing_stream.output)
  testing_stream.reset()
def test_callback_delay_barrier(self):
  """barrier_wait waits for slow callbacks, across repeated jit calls.

  A 2-second delay is installed through the module-level hook
  ``hcb.callback_extra``. The previous hook is restored in ``finally`` so the
  delay does not leak into tests that run afterwards.
  """
  old_callback_extra = hcb.callback_extra
  hcb.callback_extra = lambda dev: time.sleep(2)
  try:
    def func(x):
      for i in range(1, 4):
        x = hcb.id_print(x * i, what="x times i", output_stream=testing_stream)
      return x

    api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))
    # Wait for the results
    hcb.barrier_wait()
    # Raw numpy formatting: compared with self.assertMultiLineStrippedEqual,
    # which does not rewrite the float values.
    expected = """
        what: x times i
        [[0. 1. 2.]
         [3. 4. 5.]]
        what: x times i
        [[ 0.  2.  4.]
         [ 6.  8. 10.]]
        what: x times i
        [[ 0.  6. 12.]
         [18. 24. 30.]]"""
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)
    testing_stream.reset()

    # Call again
    api.jit(func)(np.arange(6, dtype=np.float32).reshape((2, 3)))
    hcb.barrier_wait()
    self.assertMultiLineStrippedEqual(expected, testing_stream.output)
    testing_stream.reset()
  finally:
    hcb.callback_extra = old_callback_extra
def test_jit_nested(self):
  """A jit-compiled function that itself calls a jit-compiled function.

  Taps from the outer and the inner computation arrive in program order.
  """
  def outer(v):
    first = hcb.id_print(v, where="1", output_stream=testing_stream)

    def inner(w):
      return hcb.id_print(w + 1, where="nested", output_stream=testing_stream)

    inner_out = api.jit(inner)(first)
    return hcb.id_print(inner_out + 1, where="3", output_stream=testing_stream)

  self.assertEqual(3, api.jit(outer)(1))
  hcb.barrier_wait()
  expected = """
      where: 1
      1
      where: nested
      2
      where: 3
      3"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_jit_tap_exception(self):
  """A tap function that raises is reported at barrier_wait.

  The jit result and the output of the remaining taps are unaffected.
  """
  def failing_tap(*args, **kwargs):
    # Simulate a tap error.
    raise NotImplementedError

  def func(x):
    after_x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    after_err = hcb.id_tap(failing_tap, after_x1 + 1, what="err")
    return hcb.id_print(after_err + 1, what="x3", output_stream=testing_stream)

  res = api.jit(func)(0)  # The tap error is not reported yet.
  with self.assertRaises(hcb.TapFunctionException):
    hcb.barrier_wait()

  # The receiver thread raised, yet the main thread still got 3.
  self.assertEqual(3, res)
  # All the other taps were still delivered.
  assertMultiLineStrippedEqual(self, """
      what: x1
      1
      what: x3
      3""", testing_stream.output)
  testing_stream.reset()
def test_grad_simple(self):
  """grad through two id_print taps.

  The forward pass prints the primal values. The backward pass prints the
  cotangents flowing through each tap, tagged with the jvp/transpose
  transforms.
  """
  def func(x):
    doubled = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
    tripled = hcb.id_print(
        doubled * 3., what="y * 3", output_stream=testing_stream)
    return x * tripled

  grad_func = api.grad(func)
  # TODO: re-enable the jaxpr golden check:
  # assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(grad_func)(5.)))
  res_grad = grad_func(jnp.float32(5.))
  self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
  hcb.barrier_wait()
  expected = """
      what: x * 2
      10.00
      what: y * 3
      30.00
      transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: y * 3
      5.00
      transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
      15.00"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_cond(self, with_jit=False):
  """Taps in both branches of a lax.cond; only the taken branch prints."""
  def on_true(v):
    return hcb.id_print(v, where="cond_t", output_stream=testing_stream)

  def on_false(v):
    return hcb.id_print(-1, where="cond_f", result=v,
                        output_stream=testing_stream)

  def func(x):
    first = hcb.id_print(x, where="1", output_stream=testing_stream)
    second = hcb.id_print(first + 1, where="2", output_stream=testing_stream)
    branch_out = lax.cond(x % 2 == 0, on_true, on_false, second + 1)
    return hcb.id_print(branch_out + 1, where="end",
                        output_stream=testing_stream)

  run = api.jit(func) if with_jit else func
  self.assertEqual(4, run(1))
  hcb.barrier_wait()
  expected = """
      where: 1
      1
      where: 2
      2
      where: cond_f
      -1
      where: end
      4"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_grad_double(self):
  """Second-order grad through a single id_print tap."""
  def func(x):
    tapped = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
    return x * (tapped * 3.)

  second_grad = api.grad(api.grad(func))

  # Merely building the jaxpr runs the id_print twice.
  _ = api.make_jaxpr(second_grad)(5.)
  hcb.barrier_wait()
  expected_from_jaxpr = """
      transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
      3.00
      transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
      2.00"""
  assertMultiLineStrippedEqual(self, expected_from_jaxpr,
                               testing_stream.output)
  testing_stream.reset()

  res_grad = second_grad(jnp.float32(5.))
  self.assertAllClose(12., res_grad, check_dtypes=False)
  hcb.barrier_wait()
  expected_from_run = """
      what: x * 2
      10.00
      transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
      15.00
      transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
      2.00
      transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
      3.00"""
  assertMultiLineStrippedEqual(self, expected_from_run, testing_stream.output)
  testing_stream.reset()
def test_jit_devices(self):
  """Runs a tapping jit computation once on every local device."""
  devices = api.local_devices()
  logging.info("%s: has devices %s", self._testMethodName, devices)

  def func(x, device_id):
    tag = str(device_id)  # device_id is static, so this is a Python string.
    first = hcb.id_print(x, dev=tag, output_stream=testing_stream)
    return hcb.id_print(first + 1, dev=tag, output_stream=testing_stream)

  for device in devices:
    on_device = api.jit(func, device=device, static_argnums=1)
    self.assertEqual(112, on_device(111, device.id))
  hcb.barrier_wait()

  output = testing_stream.output
  logging.info("%s: found output %s", self._testMethodName, output)
  # Each device should have printed both the input and the incremented value.
  for pattern in (r"111", r"112"):
    self.assertEqual(len(devices), len(re.findall(pattern, output)))
  testing_stream.reset()
def test_pytree(self, with_jit=False):
  """id_tap hands pytree arguments to the tap function unchanged."""
  def func(x, what=""):
    """Returns some pytrees depending on x"""
    if what == "pair_1_x":
      return (1, x)
    if what == "pair_x_2x":
      return (x, 2 * x)
    if what == "dict":
      return dict(a=2 * x, b=3 * x)
    assert False

  tap_count = 0

  def tap_func(a, what=""):
    nonlocal tap_count
    tap_count += 1
    self.assertEqual(func(5, what), a)

  maybe_jit = api.jit if with_jit else (lambda f: f)
  for what in ("pair_1_x", "pair_x_2x", "dict"):
    expected = func(10, what)

    def tapped(x):
      return hcb.id_tap(tap_func, func(x, what),
                        result=func(x * 2, what), what=what)

    self.assertEqual(expected, maybe_jit(tapped)(5))
  hcb.barrier_wait()  # Wait for receivers to be done
  self.assertEqual(3, tap_count)
def test_vmap_while_tap_cond(self):
  """Vmap of while, with a tap in the conditional."""
  def loop_cond(v):
    return hcb.id_print(v < 2, where="w_c", output_stream=testing_stream)

  def loop_body(v):
    return hcb.id_print(v + 1, where="w_b", output_stream=testing_stream)

  def func(x):  # like max(x, 2)
    start = hcb.id_print(x, where="1", output_stream=testing_stream)
    final = lax.while_loop(loop_cond, loop_body, start)
    return hcb.id_print(final, where="3", output_stream=testing_stream)

  inputs = np.arange(5, dtype=np.int32)
  batched = api.jit(api.vmap(func))
  res = batched(inputs)
  hcb.barrier_wait()
  self.assertAllClose(np.array([2, 2, 2, 3, 4]), res, check_dtypes=False)
  expected = """
      transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 1
      [0 1 2 3 4]
      transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c
      [ True  True False False False]
      transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b
      [1 2 3 4 5]
      transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c
      [ True False False False False]
      transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_b
      [2 3 3 4 5]
      transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: w_c
      [False False False False False]
      transforms: ({'name': 'batch', 'batch_dims': (0,)},) where: 3
      [2 2 2 3 4]"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_grad_primal_unused(self):
  """grad where the backward pass does not need the tapped primal value."""
  def func(x):
    tapped = hcb.id_print(x * 3., what="x * 3", output_stream=testing_stream)
    return 2. * tapped

  grad_func = api.grad(func)

  # Merely building the jaxpr runs the id_print once.
  jaxpr_text = str(api.make_jaxpr(grad_func)(5.))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
      { lambda ; a.
        let
        in (6.00,) }""", jaxpr_text)
  assertMultiLineStrippedEqual(self, """
      transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
      2.00""", testing_stream.output)
  testing_stream.reset()

  res_grad = grad_func(jnp.float32(5.))
  hcb.barrier_wait()
  self.assertAllClose(6., res_grad, check_dtypes=False)
  assertMultiLineStrippedEqual(self, """
      what: x * 3
      15.00
      transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 3
      2.00""", testing_stream.output)
  testing_stream.reset()
def test_jit_simple(self):
  """A single id_print inside a jit-compiled function."""
  def scaled(x):
    return 3. * hcb.id_print(2. * x, what="here",
                             output_stream=testing_stream)

  jit_fun1 = api.jit(scaled)
  self.assertAllClose(6. * 5., jit_fun1(5.))
  hcb.barrier_wait()
  expected = """
      what: here
      10.00"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_pmap(self):
  """pmap of fun1 over all local devices matches fun1_equiv per device."""
  n_devices = api.local_device_count()
  vargs = 2. + jnp.arange(n_devices, dtype=jnp.float32)
  pmap_fun1 = api.pmap(fun1, axis_name="i")
  res = pmap_fun1(vargs)
  hcb.barrier_wait()
  expected_res = jnp.stack([fun1_equiv(2. + a) for a in range(n_devices)])
  self.assertAllClose(expected_res, res, check_dtypes=False)
  # fun1 taps into testing_stream; clear it as the other tests do.
  testing_stream.reset()
def test_scan_cond(self, with_jit=False):
  """Taps inside a lax.cond nested in a lax.scan body."""
  def on_true(v):
    return hcb.id_print(v, where="s_t", output_stream=testing_stream)

  def on_false(v):
    return hcb.id_print(-1, where="s_f", result=v,
                        output_stream=testing_stream)

  def body(carry, elem):
    tapped = hcb.id_print(elem, where="s_1", output_stream=testing_stream)
    branch_out = lax.cond(elem % 2 == 0, on_true, on_false, tapped + 1)
    return (carry, hcb.id_print(branch_out, where="s_2",
                                output_stream=testing_stream))

  def func(x):
    first = hcb.id_print(x, where="1", output_stream=testing_stream)
    second = hcb.id_print(first + 1, where="2", output_stream=testing_stream)
    _, ys = lax.scan(body, second, jnp.arange(3))
    return hcb.id_print(ys, where="10", output_stream=testing_stream)

  run = api.jit(func) if with_jit else func
  res = run(1)
  self.assertAllClose(jnp.array([1, 2, 3]), res)
  hcb.barrier_wait()
  expected = """
      where: 1
      1
      where: 2
      2
      where: s_1
      0
      where: s_t
      1
      where: s_2
      1
      where: s_1
      1
      where: s_f
      -1
      where: s_2
      2
      where: s_1
      2
      where: s_t
      3
      where: s_2
      3
      where: 10
      [1 2 3]"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_jit_constant(self):
  """Printing a constant inside jit while returning a separate `result`."""
  def func(x):
    return hcb.id_print(42, result=x, output_stream=testing_stream)

  # TODO: re-enable the jaxpr golden check:
  # assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(api.jit(func))(5)))
  jitted = api.jit(func)
  self.assertAllClose(5, jitted(5))
  hcb.barrier_wait()
  expected = """
      42"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_with_dict_results(self):
  """id_print of a dict prints it and returns it unchanged."""
  def func2(x):
    tapped = hcb.id_print(dict(a=x * 2., b=x * 3.),
                          output_stream=testing_stream)
    return tapped["a"] + tapped["b"]

  expected_sum = 3. * (2. + 3.)
  self.assertEqual(expected_sum, func2(3.))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
      { a=6.00
        b=9.00 }""", testing_stream.output)
  testing_stream.reset()
def test_eval(self):
  """Eager (non-jit) evaluation of fun1 prints both of its taps."""
  # TODO: renable jaxpr golden tests when changing host_callback
  # assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(fun1)(5.)))
  expected_value = (5. * 2.)**2
  self.assertAllClose(expected_value, fun1(5.))
  hcb.barrier_wait()
  expected = """
      what: a * 2
      10.00
      what: y * 3
      30.00"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_vmap(self):
  """vmap of fun1: each tap prints its batch transform and batched value."""
  vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
  vmap_fun1 = api.vmap(fun1)
  # TODO: re-enable the jaxpr golden check:
  # assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_fun1)(vargs)))
  vmap_fun1(vargs)
  hcb.barrier_wait()
  expected = """
      transforms: ({'name': 'batch', 'batch_dims': (0,)},) what: a * 2
      [ 8.00 10.00]
      transforms: ({'name': 'batch', 'batch_dims': (0, 0)},) what: y * 3
      [24.00 30.00]"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_with_result(self):
  """id_print with `result=` prints the tuple but returns `result`."""
  def func2(x):
    return hcb.id_print((x * 2., x * 3.), result=x * 4.,
                        output_stream=testing_stream)

  expected_value = 3. * 4.
  self.assertEqual(expected_value, func2(3.))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
      [ 6.00
        9.00 ]""", testing_stream.output)
  testing_stream.reset()
def test_with_tuple_results(self):
  """id_print of a tuple prints it and returns it unchanged."""
  def func2(x):
    pair = hcb.id_print((x * 2., x * 3.), output_stream=testing_stream)
    return pair[0] + pair[1]

  # TODO: re-enable the jaxpr golden check:
  # assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(func2)(3.)))
  expected_sum = 3. * (2. + 3.)
  self.assertEqual(expected_sum, func2(3.))
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
      [ 6.00
        9.00 ]""", testing_stream.output)
  testing_stream.reset()
def test_while_cond(self, with_jit=False):
  """Taps inside a lax.cond nested in a lax.while_loop body."""
  def on_true(v):
    return hcb.id_print(v, where="w_b_t", output_stream=testing_stream)

  def on_false(v):
    return hcb.id_print(-1, where="w_b_f", result=v,
                        output_stream=testing_stream)

  def body(v):
    tapped = hcb.id_print(v, where="w_b_1", output_stream=testing_stream)
    branch_out = lax.cond(v % 2 == 0, on_true, on_false, tapped + 1)
    return hcb.id_print(branch_out, where="w_b_2",
                        output_stream=testing_stream)

  def func(x):
    first = hcb.id_print(x, where="1", output_stream=testing_stream)
    second = hcb.id_print(first + 1, where="2", output_stream=testing_stream)
    final = lax.while_loop(lambda v: v <= 3, body, second)
    return hcb.id_print(final, where="end", output_stream=testing_stream)

  run = api.jit(func) if with_jit else func
  self.assertEqual(4, run(1))
  hcb.barrier_wait()
  expected = """
      where: 1
      1
      where: 2
      2
      where: w_b_1
      2
      where: w_b_t
      3
      where: w_b_2
      3
      where: w_b_1
      3
      where: w_b_f
      -1
      where: w_b_2
      4
      where: end
      4"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_while(self):
  """A while_loop with a tap in its body; even without jit it is compiled."""
  y = jnp.ones(5)  # captured const

  def cond_fun(carry):
    return carry[1] < 5

  def body_fun(carry):
    counter = hcb.id_print(carry[1], output_stream=testing_stream)
    return (y, counter + 1)

  def func(x):
    return lax.while_loop(cond_fun, body_fun, (x, 1))

  func(y)
  hcb.barrier_wait()
  expected = """
      1
      2
      3
      4"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_jvp(self):
  """jvp of fun1: each tap prints its primal and, tagged jvp, its tangent."""
  def jvp_fun1(x, xt):
    return api.jvp(fun1, (x,), (xt,))

  # TODO: add a jaxpr golden check for jvp_fun1.
  primals, tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1))
  self.assertAllClose(100., primals, check_dtypes=False)
  self.assertAllClose(4., tangents, check_dtypes=False)
  hcb.barrier_wait()
  expected = """
      what: a * 2
      10.00
      transforms: ({'name': 'jvp'},) what: a * 2
      0.20
      what: y * 3
      30.00
      transforms: ({'name': 'jvp'},) what: y * 3
      0.60"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_jit_sequence1(self):
  """Two chained id_print calls under jit; also logs the jaxpr and HLO."""
  def func(x):
    first = hcb.id_print(x, where="1", output_stream=testing_stream)
    return hcb.id_print(first + 1, where="2", output_stream=testing_stream)

  name = self._testMethodName
  logging.info("%s: %s", name, api.make_jaxpr(func)(1))
  logging.info("%s: %s", name, api.xla_computation(func)(1).as_hlo_text())
  self.assertEqual(2, api.jit(func)(1))
  hcb.barrier_wait()
  expected = """
      where: 1
      1
      where: 2
      2"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_vmap_not_batched(self):
  """vmap where only one element of the tapped tuple is batched."""
  x = 3.

  def func(yv):
    # x is not mapped, yv is mapped
    _, tapped_y = hcb.id_print((x, yv), output_stream=testing_stream)
    return x + tapped_y

  vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])
  vmap_func = api.vmap(func)
  # TODO: re-enable the jaxpr golden check:
  # assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(vmap_func)(vargs)))
  _ = vmap_func(vargs)
  hcb.barrier_wait()
  expected = """
      transforms: ({'name': 'batch', 'batch_dims': (None, 0)},)
      [ 3.00
        [4.00 5.00] ]"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_jit2(self):
  """A sequence of JIT."""
  def func(x):
    first = hcb.id_print(x, where="1", output_stream=testing_stream)
    return hcb.id_print(first + 1, where="2", output_stream=testing_stream)

  for arg, want in ((1, 2), (10, 11)):
    self.assertEqual(want, api.jit(func)(arg))
  hcb.barrier_wait()
  expected = """
      where: 1
      1
      where: 2
      2
      where: 1
      10
      where: 2
      11"""
  assertMultiLineStrippedEqual(self, expected, testing_stream.output)
  testing_stream.reset()
def test_eval_tap_exception(self):
  """A tap function that raises during eager (non-jit) evaluation.

  The failure is reported as hcb.TapFunctionException inside the
  ``assertRaises`` block. The other taps are still delivered, including the
  one after the failing tap.
  """
  def tap_err(*args, **kwargs):
    # Simulate a tap error.
    raise NotImplementedError

  def func(x):
    x1 = hcb.id_print(x + 1, what="x1", output_stream=testing_stream)
    x2 = hcb.id_tap(tap_err, x1 + 1, what="err")
    x3 = hcb.id_print(x2 + 1, what="x3", output_stream=testing_stream)
    return x3

  with self.assertRaises(hcb.TapFunctionException):
    func(0)
    hcb.barrier_wait()

  # We should have received all the other taps, both before and after the
  # failing one.
  assertMultiLineStrippedEqual(self, """
      what: x1
      1
      what: x3
      3""", testing_stream.output)
  testing_stream.reset()
def test_double_vmap(self):
  """Two nested vmaps; the tap records both batch transforms."""
  # A 2D tensor with x[i, j] = i + j using 2 vmap.
  # Named `add` rather than `sum` so the builtin is not shadowed.
  def add(x, y):
    return hcb.id_print(x + y, output_stream=testing_stream)

  def sum_rows(xv, y):
    return api.vmap(add, in_axes=(0, None))(xv, y)

  def sum_all(xv, yv):
    return api.vmap(sum_rows, in_axes=(None, 0))(xv, yv)

  xv = jnp.arange(5, dtype=np.int32)
  yv = jnp.arange(3, dtype=np.int32)
  # TODO: re-enable the jaxpr golden check:
  # assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(sum_all)(xv, yv)))
  _ = sum_all(xv, yv)
  hcb.barrier_wait()
  assertMultiLineStrippedEqual(self, """
      transforms: ({'name': 'batch', 'batch_dims': (0,)}, {'name': 'batch', 'batch_dims': (0,)})
      [[0 1 2 3 4]
       [1 2 3 4 5]
       [2 3 4 5 6]]""", testing_stream.output)
  testing_stream.reset()
def test_multiple_tap(self, concurrent=False):
  """Call id_tap multiple times, concurrently or in sequence.

  Each call jit-compiles a computation that taps its index into `pause_tap`.
  That tap records the index and sleeps briefly. After barrier_wait, every
  index must have been received.
  """
  if concurrent and jtu.device_under_test() == "gpu":
    # TODO(necula): it seems that on GPU if multiple host threads run
    # a jit computation, the multiple computations are interleaved on the
    # GPU. This can result in the outfeed trains being interleaved, which
    # will trigger an error. The solution is to fix on GPU the receiving
    # logic so that we can outfeed the train as one tuple, and receive it
    # one piece as a time. Then the trains should be atomic.
    # See also b/160692602.
    raise SkipTest("concurrent id_tap not supported on GPU")
  received = set()
  count = 5
  sleep_secs = 0.3  # Keep the logged duration in sync with the real sleep.

  def pause_tap(idx, **kwargs):
    received.add(int(idx))
    logging.info("Starting do_tap %s. Sleeping %ssec ...", idx, sleep_secs)
    time.sleep(sleep_secs)
    logging.info("Finish do_tap %s", idx)

  def do_tap(idx):
    api.jit(lambda idx: hcb.id_tap(pause_tap, idx))(idx)

  if concurrent:
    threads = [
        threading.Thread(name=f"enqueue_tap_{idx}", target=do_tap,
                         args=(idx,))
        for idx in range(count)
    ]
    for t in threads:
      t.start()
    for t in threads:
      t.join()
  else:
    for idx in range(count):
      do_tap(idx)

  hcb.barrier_wait()
  self.assertEqual(received, set(range(count)))
def test_jit_interleaving(self):
  """Several jit calls without data dependencies between them.

  Each call makes a chain of id_tap calls, each tapping a list of arrays.
  The taps of independent calls may interfere. Every tap must still arrive
  with all of its arrays, and each call must return the correct value.
  """
  count = 0  # Count tap invocations
  nr_arrays = 5  # arrays per tap
  nr_taps = 10  # taps per computation
  nr_calls = 10  # independent jit invocations

  def tap_func(arg, **_):
    nonlocal count
    assert len(arg) == nr_arrays
    count += 1

  # This is the function that we'll run multiple times. Its parameter is not
  # named `count`, so it does not shadow the tap counter above.
  def func(x, n_taps):
    for i in range(n_taps):
      # Tap [x, x + 1, ..., x + nr_arrays - 1]; continue with the last one.
      x = hcb.id_tap(tap_func, [x + j for j in range(nr_arrays)], i=i)[-1]
    return x

  x = jnp.array(1, dtype=np.int32)
  res = 0
  for _ in range(nr_calls):
    # No dependencies between the jit invocations
    res += api.jit(lambda x: func(x, nr_taps))(x)
  hcb.barrier_wait()
  self.assertEqual(nr_calls * nr_taps, count)
  # Each call returns 1 + nr_taps * (nr_arrays - 1) == 41.
  self.assertEqual(nr_calls * (1 + nr_taps * (nr_arrays - 1)), res)