def test_xla_computation(self):
    """Exercise the examples given in the ``api.xla_computation`` docstring."""
    # Elementwise program: both trig ops must appear in the emitted HLO.
    def h(x):
        return np.sin(np.cos(x))

    hlo_text = api.xla_computation(h)(2.).GetHloText()
    self.assertIn('cosine', hlo_text)
    self.assertIn('sine', hlo_text)

    # One named axis 'i' of size 4: the psum lowers to an all-reduce whose
    # replica group spans all four replicas.
    def f(x):
        return x - lax.psum(x, 'i')

    hlo_text = api.xla_computation(f, axis_env=[('i', 4)])(2).GetHloText()
    self.assertIn('all-reduce', hlo_text)
    self.assertIn('replica_groups={{0,1,2,3}}', hlo_text)

    # Two named axes ('i': 4, 'j': 2): reducing over each axis and over both
    # yields three different replica groupings in the same computation.
    def g(x):
        rowsum = lax.psum(x, 'i')
        colsum = lax.psum(x, 'j')
        allsum = lax.psum(x, ('i', 'j'))
        return rowsum, colsum, allsum

    hlo_text = api.xla_computation(g, axis_env=[('i', 4), ('j', 2)])(5.).GetHloText()
    self.assertIn('all-reduce', hlo_text)
    for expected_groups in ('replica_groups={{0,2,4,6},{1,3,5,7}}',
                            'replica_groups={{0,1},{2,3},{4,5},{6,7}}',
                            'replica_groups={{0,1,2,3,4,5,6,7}}'):
        self.assertIn(expected_groups, hlo_text)
def test_xla_computation_args(self):
    """Check how ``tuple_args`` affects the HLO parameter list."""
    def foo(x, y, z):
        return x + y + z

    # Default: one HLO parameter per positional argument.
    flat_comp = api.xla_computation(foo)(1., 2., 3.)
    self.assertEqual(len(flat_comp.GetProgramShape().parameter_shapes()), 3)

    # tuple_args=True: a single parameter whose element type is TUPLE.
    tupled_comp = api.xla_computation(foo, tuple_args=True)(1., 2., 3.)
    shapes = tupled_comp.GetProgramShape().parameter_shapes()
    self.assertEqual(len(shapes), 1)
    self.assertEqual(shapes[0].xla_element_type(),
                     xb.xla_client.PrimitiveType.TUPLE)
def test_scan_cond(self, with_jit=False):
    """Print from both branches of a ``cond`` nested in a ``scan`` body.

    Args:
      with_jit: when true, the traced function is run under ``api.jit``.
    """
    def func(x):
        printed = hcb.id_print(x, where="1", output_stream=testing_stream)
        carry0 = hcb.id_print(printed + 1, where="2", output_stream=testing_stream)

        def body(c, x):
            elem = hcb.id_print(x, where="s_1", output_stream=testing_stream)
            # Even elements take the printing true branch; odd elements print
            # -1 but still return the operand via result=.
            branch_out = lax.cond(
                x % 2 == 0,
                elem + 1,
                lambda v: hcb.id_print(v, where="s_t", output_stream=testing_stream),
                elem + 1,
                lambda v: hcb.id_print(-1, where="s_f", result=v,
                                       output_stream=testing_stream))
            return c, hcb.id_print(branch_out, where="s_2", output_stream=testing_stream)

        _, stacked = lax.scan(body, carry0, jnp.arange(3))
        return hcb.id_print(stacked, where="10", output_stream=testing_stream)

    logging.warning("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
    logging.warning("%s: %s", self._testMethodName,
                    api.xla_computation(func)(1).as_hlo_text())
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
        run = api.jit(func) if with_jit else func
        outcome = run(1)
        self.assertAllClose(jnp.array([1, 2, 3]), outcome, check_dtypes=True)
    assertMultiLineStrippedEqual(self, """
        where: 1
        1
        where: 2
        2
        where: s_1
        0
        where: s_t
        1
        where: s_2
        1
        where: s_1
        1
        where: s_f
        -1
        where: s_2
        2
        where: s_1
        2
        where: s_t
        3
        where: s_2
        3
        where: 10
        [1 2 3]""", testing_stream.output)
    testing_stream.reset()
def test_jit_interleaving(self):
    """Several independent jit invocations whose taps may interleave.

    Runs 10 jitted calls with no data dependencies between them. Each call
    performs 10 taps, so the tap consumer must run exactly 100 times once the
    outfeed receiver has shut down. Each tap also forwards ``x + 4``, so every
    call returns ``1 + 10 * 4``.
    """
    # Several jit's without data dependencies; they may interfere
    count = 0  # Count tap invocations
    nr_arrays = 5
    nr_taps = 10   # taps performed per call of `func`
    nr_calls = 10  # independent jit invocations

    def tap_func(arg, **kwargs):
        nonlocal count
        assert len(arg) == nr_arrays
        count += 1

    # This is the function that we'll run multiple times. Its loop bound is
    # named `nr_taps` (not `count`) so it does not shadow the tap counter.
    def func(x, nr_taps):
        for i in range(nr_taps):
            # Tap nr_arrays values and continue with the last, x + (nr_arrays - 1).
            x = hcb.id_tap(tap_func, [x + j for j in range(nr_arrays)], i=i)[-1]
        return x

    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
        x = jnp.array(1, dtype=np.int32)
        res = 0
        for _ in range(nr_calls):
            # No dependencies between the jit invocations
            res += api.jit(lambda x: func(x, nr_taps))(x)
    logging.warning(
        "%s: %s", self._testMethodName,
        api.xla_computation(lambda x: func(x, 5))(1).as_hlo_text())
    # Checked after the receiver exits so every outfeed has been consumed.
    self.assertEqual(nr_calls * (1 + nr_taps * (nr_arrays - 1)), int(res))
    self.assertEqual(nr_calls * nr_taps, count)
def test_cond(self, with_jit=False):
    """A conditional whose branches print; input 1 takes the false branch.

    Args:
      with_jit: when true, the traced function is run under ``api.jit``.
    """
    def func(x):
        first = hcb.id_print(x, where="1", output_stream=testing_stream)
        second = hcb.id_print(first + 1, where="2", output_stream=testing_stream)
        branched = lax.cond(
            x % 2 == 0,
            second + 1,
            lambda v: hcb.id_print(v, where="cond_t", output_stream=testing_stream),
            second + 1,
            lambda v: hcb.id_print(-1, where="cond_f", result=v,
                                   output_stream=testing_stream))
        return hcb.id_print(branched + 1, where="end", output_stream=testing_stream)

    logging.warning("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
    logging.warning("%s: %s", self._testMethodName,
                    api.xla_computation(func)(1).as_hlo_text())
    run = api.jit(func) if with_jit else func
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
        self.assertEqual(4, run(1))
    assertMultiLineStrippedEqual(self, """
        where: 1
        1
        where: 2
        2
        where: cond_f
        -1
        where: end
        4""", testing_stream.output)
    testing_stream.reset()
def test_jit_while_pred_printing(self):
    """While with printing in the conditional."""
    raise SkipTest("Not yet implemented")  # TODO: implement printing inside conditional

    # Unreachable while the skip above is in place; kept as the intended test.
    def func(x):
        # Routed to testing_stream like every other print, so it is captured
        # by the output check below once the test is enabled.
        x1 = hcb.id_print(x, where="1", output_stream=testing_stream)

        def body(x):
            x3 = hcb.id_print(x, where="w_1", output_stream=testing_stream)
            return hcb.id_print(x3 + 1, where="w_2", output_stream=testing_stream)

        x10 = lax.while_loop(
            lambda x: hcb.id_print(x < 10, where="w_p", output_stream=testing_stream),
            body, x1)
        res = hcb.id_print(x10, where="10", output_stream=testing_stream)
        return res

    logging.warning("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
    logging.warning("%s: %s", self._testMethodName,
                    api.xla_computation(func)(1).as_hlo_text())
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
        self.assertEqual(10, api.jit(func)(1))
    # TODO: fill in the expected printed output once this test is enabled.
    assertMultiLineStrippedEqual(self, """ """, testing_stream.output)
    testing_stream.reset()
def test_jit_nested(self):
    """Prints from an inner jit nested inside an outer jit, in program order."""
    def func(x):
        outer_in = hcb.id_print(x, where="1", output_stream=testing_stream)

        def func_nested(x):
            return hcb.id_print(x + 1, where="nested", output_stream=testing_stream)

        inner_out = api.jit(func_nested)(outer_in)
        return hcb.id_print(inner_out + 1, where="3", output_stream=testing_stream)

    logging.warning("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
    logging.warning("%s: %s", self._testMethodName,
                    api.xla_computation(func)(1).as_hlo_text())
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
        self.assertEqual(3, api.jit(func)(1))
    assertMultiLineStrippedEqual(self, """
        where: 1
        1
        where: nested
        2
        where: 3
        3""", testing_stream.output)
    testing_stream.reset()
def test_xla_computation_instantiate_constant_outputs(self):
    """A constant-returning function, staged with instantiate_const_outputs=True,
    reports a single (3, 4) output in its return-value shape."""
    def f():
        return np.zeros((3, 4))

    comp = api.xla_computation(f, instantiate_const_outputs=True)()
    result_shapes = comp.GetReturnValueShape().tuple_shapes()
    out_shape, = result_shapes
    self.assertEqual(out_shape.dimensions(), (3, 4))
def test_jit_nested_cond_no_print(self):
    """A nested conditional, without any prints"""
    raise SkipTest("skip this")

    # Unreachable while the skip above is in place; kept as the intended test.
    @api.jit
    def cfun(x):
        return lax.cond(
            lax.lt(x, 2),
            lambda x: x,
            lambda x: lax.cond(x < 5, 3, lambda x: x, 4, lambda y: y),
            x)

    # Log the HLO like the sibling tests do, instead of printing to stdout.
    logging.warning("%s: %s", self._testMethodName,
                    api.xla_computation(cfun)(1).as_hlo_text())
    cfun(1)
def test_while_cond(self, with_jit=False):
    """Print from both branches of a ``cond`` inside a ``while_loop`` body.

    Args:
      with_jit: when true, the traced function is run under ``api.jit``.
    """
    def func(x):
        first = hcb.id_print(x, where="1", output_stream=testing_stream)
        start = hcb.id_print(first + 1, where="2", output_stream=testing_stream)

        def body(x):
            cur = hcb.id_print(x, where="w_b_1", output_stream=testing_stream)
            nxt = lax.cond(
                x % 2 == 0,
                cur + 1,
                lambda v: hcb.id_print(v, where="w_b_t", output_stream=testing_stream),
                cur + 1,
                lambda v: hcb.id_print(-1, where="w_b_f", result=v,
                                       output_stream=testing_stream))
            return hcb.id_print(nxt, where="w_b_2", output_stream=testing_stream)

        final = lax.while_loop(lambda x: x <= 3, body, start)
        return hcb.id_print(final, where="end", output_stream=testing_stream)

    logging.warning("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
    logging.warning("%s: %s", self._testMethodName,
                    api.xla_computation(func)(1).as_hlo_text())
    run = api.jit(func) if with_jit else func
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
        self.assertEqual(4, run(1))
    assertMultiLineStrippedEqual(self, """
        where: 1
        1
        where: 2
        2
        where: w_b_1
        2
        where: w_b_t
        3
        where: w_b_2
        3
        where: w_b_1
        3
        where: w_b_f
        -1
        where: w_b_2
        4
        where: end
        4""", testing_stream.output)
    testing_stream.reset()
def testIssue810(self):
    """Regression test for issue 810: the grad of a scan over ``matmul(A, x)``
    must emit fewer than two dynamic-update-slice ops in its HLO."""
    def loss(A):
        def step(x, i):
            return np.matmul(A, x), None
        init_x = np.zeros(A.shape[-1:])
        last_x, _ = lax.scan(step, init_x, np.arange(10))
        return np.sum(last_x)

    A = np.zeros((3, 3))
    # The second DUS was unnecessarily replicating A across time.
    # We check XLA because _scan_impl is "underneath" the jaxpr language.
    hlo_text = str(api.xla_computation(api.grad(loss))(A).GetHloText())
    # assertLess (not a bare assert) reports the actual count on failure and
    # is not stripped when Python runs with -O.
    self.assertLess(hlo_text.count("dynamic-update-slice("), 2)
def test_jit_simple(self):
    """A single id_print inside a jitted lambda prints the doubled input."""
    jit_fun1 = api.jit(
        lambda x: 3. * hcb.id_print(2. * x, what="here",
                                    output_stream=testing_stream))
    hlo_text = api.xla_computation(jit_fun1)(5.).as_hlo_text()
    logging.warning("%s: %s", self._testMethodName, hlo_text)
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
        outcome = jit_fun1(5.)
    self.assertAllClose(6. * 5., outcome, check_dtypes=True)
    assertMultiLineStrippedEqual(self, """
        what: here
        10.00""", testing_stream.output)
    testing_stream.reset()
def test_jit_sequence1(self):
    """Two chained id_print calls under jit print in program order."""
    def func(x):
        first = hcb.id_print(x, where="1", output_stream=testing_stream)
        return hcb.id_print(first + 1, where="2", output_stream=testing_stream)

    logging.info("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
    logging.info("%s: %s", self._testMethodName,
                 api.xla_computation(func)(1).as_hlo_text())
    jitted = api.jit(func)
    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
        self.assertEqual(2, jitted(1))
    assertMultiLineStrippedEqual(self, """
        where: 1
        1
        where: 2
        2""", testing_stream.output)
    testing_stream.reset()
def test_staging_out_multi_replica(self):
    """Staging a pmap-wrapped function out to XLA must not crash."""
    def f(x):
        return api.pmap(np.mean)(x)

    # Smoke test only: building the computation and rendering its HLO text.
    api.xla_computation(f)(np.arange(8)).GetHloText()