def test_check_jaxpr_cond_invalid(self):
    """check_jaxpr rejects a cond whose branches disagree on input count.

    Traces a two-branch ``lax.switch``, empties the input variables of
    branch 0 in the resulting ``cond`` equation, and expects
    ``core.check_jaxpr`` to raise ``core.JaxprTypeError``.
    """
    switch_fn = lambda x: lax.switch(0, [jnp.sin, jnp.cos], x)
    jaxpr = make_jaxpr(switch_fn)(1.).jaxpr

    # Locate the equation by primitive name; the switch is expected to
    # appear as a 'cond' equation.
    cond_eqn = next(e for e in jaxpr.eqns if e.primitive.name == 'cond')

    # Deliberately corrupt branch 0 so that it takes no inputs.
    cond_eqn.params['branches'][0].jaxpr.invars = ()

    with self.assertRaisesRegex(
            core.JaxprTypeError,
            'cond branch 0 takes 0 inputs, branch 1 takes 1'):
        core.check_jaxpr(jaxpr)
def test_jaxpr_dropvar_from_loop(self):
    """An unused ``while_loop`` output is a dropvar and the jaxpr still checks.

    ``f`` discards the first element of the loop carry; the first output
    variable of the first equation must then be ``core.dropvar``.
    """
    def f(x):
        _, y = lax.while_loop(lambda s: s[0] < 0.,
                              lambda s: (jnp.sin(s[0]), jnp.cos(s[1])),
                              (x, x))
        return y + 1.

    jaxpr = make_jaxpr(f)(1.).jaxpr
    # Use a unittest assertion rather than a bare `assert`, which is
    # stripped when Python runs with -O and would make the check vacuous.
    self.assertIs(jaxpr.eqns[0].outvars[0], core.dropvar)
    core.check_jaxpr(jaxpr)
def test_jaxpr_undefined_eqn_invar(self):
    """check_jaxpr reports an equation input that was never defined.

    The first input of the ``cos`` equation is replaced by a freshly
    generated variable (suffix ``_test``) that no equation defines.
    """
    jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
    cos_eqn = next(e for e in jaxpr.eqns if e.primitive.name == 'cos')

    # Build a new variable of the same abstract value that is not bound
    # anywhere in the jaxpr, and splice it in as the equation's input.
    new_var = core.gensym([jaxpr], suffix='_test')
    cos_eqn.invars[0] = new_var(cos_eqn.invars[0].aval)

    with self.assertRaisesRegex(
            core.JaxprTypeError,
            r"Variable '.+_test' not defined\n\nin equation:"):
        core.check_jaxpr(jaxpr)
def test_jaxpr_dropvar_from_cond(self):
    """An unused ``cond`` output is a dropvar and the jaxpr still checks.

    ``f`` discards the first output of the conditional; the first output
    variable of the last equation must then be ``core.dropvar``.
    """
    def f(x):
        _, y = lax.cond(x < 0.,
                        lambda x: (jnp.sin(x), x + 1.),
                        lambda x: (jnp.cos(x), x + 2.),
                        x)
        return y

    jaxpr = make_jaxpr(f)(1.).jaxpr
    # Use a unittest assertion rather than a bare `assert`, which is
    # stripped when Python runs with -O and would make the check vacuous.
    self.assertIs(jaxpr.eqns[-1].outvars[0], core.dropvar)
    core.check_jaxpr(jaxpr)
def test_check_jaxpr_scan_correct(self):
    """A jaxpr obtained by tracing ``lax.scan`` passes ``core.check_jaxpr``."""
    def f(c, x):
        # Same operations, in the same order, as a single nested expression.
        sin_total = jnp.sum(jnp.sin(x))
        cos_total = jnp.sum(jnp.cos(c))
        b = jnp.cos(sin_total + cos_total)
        return jnp.sin(c * b), b

    scanned_inputs = jnp.ones((5, 3))
    initial_carry = jnp.ones(4)
    scan_f = partial(lax.scan, f)
    traced = make_jaxpr(scan_f)(initial_carry, scanned_inputs)
    core.check_jaxpr(traced.jaxpr)
def test_const(self):
    """The printed jaxpr of a function returning constants matches the golden text."""
    def fun(x):
        return (x, 1., np.zeros(1))

    rendered = str(api.make_jaxpr(fun)(0.))
    golden = """ { lambda b ; ; a. let in [a, 1.0, b] } """
    self.assertMultiLineStrippedEqual(rendered, golden)
def testNormalize(self):
    """``_parallelize`` of a normalisation matches the plain result and uses psum."""
    def f(x):
        return x / x.sum(0)

    data = onp.arange(4.)
    want = f(data)
    got = _parallelize(f)(data)
    self.assertAllClose(got, want, check_dtypes=False)

    # The traced program's repr is expected to mention a 'psum'.
    traced = make_jaxpr(_parallelize(f))(data)
    self.assertIn('psum', repr(traced))
def test_jaxpr_dropvar_from_jit_call(self):
    """An unused output of a jitted call is a dropvar and the jaxpr still checks.

    ``f`` discards the first result of ``jit(inner)``; the first output
    variable of the first equation must then be ``core.dropvar``.
    """
    def inner(x):
        return x + 1, x + 2

    def f(x):
        _, y = jit(inner)(x)
        return y + 3

    jaxpr = make_jaxpr(f)(1).jaxpr
    # Use a unittest assertion rather than a bare `assert`, which is
    # stripped when Python runs with -O and would make the check vacuous.
    self.assertIs(jaxpr.eqns[0].outvars[0], core.dropvar)
    core.check_jaxpr(jaxpr)
def testDotGeneralContractAndBatch(self, lhs_shape, rhs_shape, dtype,
                                   dimension_numbers, bdims):
    """Check batching of ``lax.dot_general`` and inspect the traced primitives.

    Args:
      lhs_shape: shape of the left operand.
      rhs_shape: shape of the right operand.
      dtype: dtype used for both operands.
      dimension_numbers: forwarded to ``lax.dot_general``.
      bdims: batch dimensions forwarded to ``self._CheckBatching``.
    """
    rng = jtu.rand_small(self.rng())
    dot = partial(lax.dot_general, dimension_numbers=dimension_numbers)
    self._CheckBatching(dot, 5, bdims, (lhs_shape, rhs_shape), (dtype, dtype),
                        rng)

    # Checks that batching didn't introduce any transposes or broadcasts.
    # NOTE(review): the jaxpr inspected here is traced from the unbatched
    # `dot`; confirm whether the batched function was intended.
    jaxpr = api.make_jaxpr(dot)(np.zeros(lhs_shape, dtype),
                                np.zeros(rhs_shape, dtype))
    for eqn in jtu.iter_eqns(jaxpr.jaxpr):
        # Compare the primitive's *name*: testing the primitive object itself
        # for membership in a list of strings can never match, which made the
        # original assertion vacuous.
        self.assertNotIn(eqn.primitive.name, ("transpose", "broadcast"))
def testSelect(self):
    """``_papply`` of ``lax.select`` matches a hand-written SPMD version.

    Compares the repr of the traced jaxprs, then checks the numerical result
    of running the parallelised function through ``_serial_pmap``.
    """
    pfun, axis_name = _papply(lax.select, 5, in_axes=(None, 0, None))

    p = onp.arange(15).reshape((5, 3)) % 4 == 1
    t = onp.ones((5, 3))
    f = onp.zeros((5, 3))

    jaxpr = make_jaxpr(pfun)(p, t[0], f)

    def expected_spmd(p, t, f):
        return lax.select(
            lax_parallel.psplit_like(p, t, axis_name),
            t,
            lax_parallel.psplit_like(f, t, axis_name))

    expected_jaxpr = make_jaxpr(expected_spmd)(p, t[0], f)
    # A unittest assertion survives `python -O` and reports both reprs on
    # failure, unlike a bare `assert`.
    self.assertEqual(repr(jaxpr), repr(expected_jaxpr))

    ans = _serial_pmap(pfun, axis_name, in_axes=(None, 0, None))(p, t, f)
    expected = lax.select(p, t, f)
    self.assertAllClose(ans, expected, check_dtypes=True)
def test_grad_simple(self):
    """Check ``api.grad`` of a function containing two ``hcb.id_print`` calls.

    First compares the printed jaxpr of the gradient function against a
    golden string, then runs the gradient under ``hcb.outfeed_receiver`` and
    compares both the numerical result and the text captured in
    ``testing_stream``.
    """
    def func(x):
        y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
        return x * hcb.id_print(
            y * 3., what="y * 3", output_stream=testing_stream)

    grad_func = api.grad(func)
    # Golden jaxpr of the gradient function.
    assertMultiLineStrippedEqual(
        self, """ { lambda ; a. let b = mul 1.00 a c d = id_tap[ arg_treedef=* func=_print nr_untapped=1 transforms=(('jvp',), ('transpose',)) what=y * 3 ] b 0.00 e = mul c 3.00 f g = id_tap[ arg_treedef=* func=_print nr_untapped=1 transforms=(('jvp',), ('transpose',)) what=x * 2 ] e 0.00 h = mul f 2.00 i = mul a 2.00 j = id_tap[ arg_treedef=* func=_print nr_untapped=0 what=x * 2 ] i k = mul j 3.00 l = id_tap[ arg_treedef=* func=_print nr_untapped=0 what=y * 3 ] k m = mul 1.00 l n = add_any h m in (n,) }""", str(api.make_jaxpr(grad_func)(5.)))

    with hcb.outfeed_receiver():
        res_grad = grad_func(jnp.float32(5.))
    self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
    # Golden text expected in the test stream after the gradient has run.
    assertMultiLineStrippedEqual(
        self, """ what: x * 2 10.00 what: y * 3 30.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: y * 3 5.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 15.00""", testing_stream.output)
    # Clear the captured output so it does not leak into other tests.
    testing_stream.reset()
def test_while_cond(self, with_jit=False):
    """Check ``hcb.id_print`` inside a ``lax.cond`` nested in a ``lax.while_loop``.

    Args:
      with_jit: when true the function is wrapped in ``api.jit`` before being
        run; otherwise it is run as-is.
    """
    def func(x):
        x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
        x2 = hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

        def body(x):
            x3 = hcb.id_print(x, where="w_b_1", output_stream=testing_stream)
            # Both branches receive x3 + 1; the true branch prints its
            # operand, the false branch prints -1 and passes its operand
            # through via `result=`.
            x4 = lax.cond(
                x % 2 == 0, x3 + 1, lambda x: hcb.id_print(
                    x, where="w_b_t", output_stream=testing_stream),
                x3 + 1, lambda x: hcb.id_print(-1, where="w_b_f", result=x,
                                               output_stream=testing_stream))
            return hcb.id_print(x4, where="w_b_2", output_stream=testing_stream)

        x10 = lax.while_loop(lambda x: x <= 3, body, x2)
        res = hcb.id_print(x10, where="end", output_stream=testing_stream)
        return res

    # Log the jaxpr and the HLO text of the function for debugging.
    logging.warning("%s: %s", self._testMethodName, api.make_jaxpr(func)(1))
    logging.warning("%s: %s", self._testMethodName,
                    api.xla_computation(func)(1).as_hlo_text())
    # Identity transform when jit is not requested.
    transform = api.jit if with_jit else lambda f: f

    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
        self.assertEqual(4, transform(func)(1))
    # Golden text expected in the test stream.
    assertMultiLineStrippedEqual(
        self, """ where: 1 1 where: 2 2 where: w_b_1 2 where: w_b_t 3 where: w_b_2 3 where: w_b_1 3 where: w_b_f -1 where: w_b_2 4 where: end 4""", testing_stream.output)
    # Clear the captured output so it does not leak into other tests.
    testing_stream.reset()
def assertRewrite(self, expected: str, func: Callable, args: Sequence,
                  has_input_token=True, has_output_token=True):
    """Check that the rewrite of func(*args) matches expected.

    The comparison itself is currently disabled (see the TODO below); only
    the tracing of ``func(*args)`` is performed, and its result is discarded.
    ``expected``, ``has_input_token`` and ``has_output_token`` are unused
    until the check is re-enabled.
    """
    # Trace the function; the resulting jaxpr is intentionally not kept.
    api.make_jaxpr(func)(*args)
    # TODO: re-enable when we change the host_callback rewriter
    # rewritten = hcb._rewrite_typed_jaxpr(jaxpr,
    #                                      has_input_token, has_output_token)
    # assertMultiLineStrippedEqual(self, expected, str(rewritten))
def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype,
                                        dimension_numbers, rng_factory):
    """Check gradients of ``lax.dot_general`` and that precision is preserved.

    Args:
      lhs_shape: shape of the left operand.
      rhs_shape: shape of the right operand.
      dtype: dtype used for both operands.
      dimension_numbers: forwarded to ``lax.dot_general``.
      rng_factory: called with ``self.rng()`` to build the sample generator.
    """
    rng = rng_factory(self.rng())
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot_general = partial(lax.dot_general, dimension_numbers=dimension_numbers,
                          precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot_general, (lhs, rhs), order=2,
                         modes=["fwd", "rev"])

    # check that precision config is preserved
    result, pullback = api.vjp(dot_general, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    # assertIn survives `python -O` and prints the jaxpr text on failure,
    # unlike a bare `assert`.
    self.assertIn("precision=HIGHEST", s)
def assertRewrite(self, expected: str, func: Callable, args: Sequence,
                  has_input_token=True, has_output_token=True):
    """Check that the rewrite of func(*args) matches expected.

    Traces ``func(*args)``, passes the result through
    ``hcb._rewrite_typed_jaxpr`` with the two token flags, and compares the
    string form of element 0 of the rewriter's result against ``expected``.
    """
    traced = api.make_jaxpr(func)(*args)
    rewrite_result = hcb._rewrite_typed_jaxpr(traced, has_input_token,
                                              has_output_token)
    actual = str(rewrite_result[0])
    assertMultiLineStrippedEqual(self, expected, actual)
def testDotGrad(self, lhs_shape, rhs_shape, dtype, rng_factory):
    """Check gradients of ``lax.dot`` and that precision is preserved.

    Args:
      lhs_shape: shape of the left operand.
      rhs_shape: shape of the right operand.
      dtype: dtype used for both operands.
      rng_factory: called with ``self.rng()`` to build the sample generator.
    """
    rng = rng_factory(self.rng())
    # Per-dtype tolerances, used for both atol and rtol below.
    tol = {onp.float16: 1e-1, onp.float32: 1e-4}
    lhs = rng(lhs_shape, dtype)
    rhs = rng(rhs_shape, dtype)
    dot = partial(lax.dot, precision=lax.Precision.HIGHEST)
    check_grads_bilinear(dot, (lhs, rhs), order=2, modes=["fwd", "rev"],
                         atol=tol, rtol=tol)

    # check that precision config is preserved
    result, pullback = api.vjp(dot, lhs, rhs)
    gresult = lax.zeros_like_array(result)
    s = str(api.make_jaxpr(pullback)(gresult))
    # assertIn survives `python -O` and prints the jaxpr text on failure,
    # unlike a bare `assert`.
    self.assertIn("precision=HIGHEST", s)
def testNestedBatchingMatMat(self):
    """Two nested ``vmap``s of ``vdot`` match ``onp.dot`` and trace to one equation."""
    row_dot = vmap(np.vdot, in_axes=(0, None))
    mat_dot = vmap(row_dot, in_axes=(None, 1), out_axes=1)

    # Seeded generator so the inputs are reproducible.
    randn = onp.random.RandomState(0).randn
    lhs = randn(4, 3)
    rhs = randn(3, 2)

    got = mat_dot(lhs, rhs)
    want = onp.dot(lhs, rhs)
    self.assertAllClose(got, want, check_dtypes=False)

    traced = make_jaxpr(mat_dot)(lhs, rhs)
    self.assertEqual(len(traced.eqns), 1)
def test_jarrett_jvps(self):
    """``api.jarrett`` preserves values and VJPs of a nested-sine function.

    For a scalar and an array input, compares the wrapped and unwrapped
    functions and their pullbacks, then checks the number of ``constvars`` in
    the traced pullback of the wrapped function.
    """
    def f1(x):
        return np.sin(np.sin(np.sin(x)))

    f2 = api.jarrett(f1)

    for x in [3., onp.array([2., 3., 4.])]:
        self.assertAllClose(f1(x), f2(x), check_dtypes=True)

        _, f1_vjp = api.vjp(f1, x)
        _, f2_vjp = api.vjp(f2, x)
        self.assertAllClose(f1_vjp(x), f2_vjp(x), check_dtypes=True)

        jaxpr2 = api.make_jaxpr(f2_vjp)(x)
        # assertEqual survives `python -O` and reports the actual count on
        # failure, unlike a bare `assert`.
        self.assertEqual(len(jaxpr2.constvars), 1)
def test_jvp(self):
    """Check ``api.jvp`` of ``fun1`` against golden jaxpr and printed output.

    ``fun1`` is defined elsewhere in the file; the golden jaxpr below shows
    the ``id_tap`` equations its trace is expected to contain.
    """
    jvp_fun1 = lambda x, xt: api.jvp(fun1, (x, ), (xt, ))
    # Golden jaxpr of the JVP function.
    assertMultiLineStrippedEqual(
        self, """ { lambda ; a b. let c = mul a 2.00 d = id_tap[ arg_treedef=* func=_print nr_untapped=0 what=a * 2 ] c e = mul d 3.00 f g = id_tap[ arg_treedef=* func=_print nr_untapped=1 what=y * 3 ] e d h = mul g g i = mul b 2.00 j k = id_tap[ arg_treedef=* func=_print nr_untapped=1 transforms=('jvp',) what=a * 2 ] i d l = mul j 3.00 m n o = id_tap[ arg_treedef=* func=_print nr_untapped=2 transforms=('jvp',) what=y * 3 ] l j f p = mul n g q = mul g n r = add_any p q in (h, r) }""", str(api.make_jaxpr(jvp_fun1)(jnp.float32(5.), jnp.float32(0.1))))

    with hcb.outfeed_receiver():
        res_primals, res_tangents = jvp_fun1(jnp.float32(5.), jnp.float32(0.1))
    self.assertAllClose(100., res_primals, check_dtypes=False)
    self.assertAllClose(4., res_tangents, check_dtypes=False)
    # Golden text expected in the test stream after the JVP has run.
    assertMultiLineStrippedEqual(
        self, """ what: a * 2 10.00 transforms: ('jvp',) what: a * 2 0.20 what: y * 3 30.00 transforms: ('jvp',) what: y * 3 0.60""", testing_stream.output)
    # Clear the captured output so it does not leak into other tests.
    testing_stream.reset()
def test_jarrett_jvps2(self):
    """``api.jarrett`` preserves values and VJPs of a two-argument function.

    For a scalar pair and an array pair, compares the wrapped and unwrapped
    functions and their pullbacks, then checks the number of ``constvars`` in
    the traced pullback of the wrapped function.
    """
    def f1(x, y):
        return np.sin(x) * np.cos(y) * np.sin(x) * np.cos(y)

    f2 = api.jarrett(f1)

    # TODO(mattjj): doesn't work for (3., onp.array([4., 5.]))
    for x, y in [(3., 4.), (onp.array([5., 6.]), onp.array([7., 8.]))]:
        self.assertAllClose(f1(x, y), f2(x, y), check_dtypes=True)

        _, f1_vjp = api.vjp(f1, x, y)
        _, f2_vjp = api.vjp(f2, x, y)
        self.assertAllClose(f1_vjp(y), f2_vjp(y), check_dtypes=True)

        jaxpr2 = api.make_jaxpr(f2_vjp)(y)
        # assertEqual survives `python -O` and reports the actual count on
        # failure, unlike a bare `assert`.
        self.assertEqual(len(jaxpr2.constvars), 2)
def testNestedBatchingMatMat(self):
    """Two nested ``vmap``s of ``vdot`` match ``np.dot`` and trace to one equation."""
    row_dot = vmap(jnp.vdot, in_axes=(0, None))
    mat_dot = vmap(row_dot, in_axes=(None, 1), out_axes=1)

    # Seeded generator so the inputs are reproducible.
    randn = np.random.RandomState(0).randn
    lhs = randn(4, 3)
    rhs = randn(3, 2)

    got = mat_dot(lhs, rhs)
    want = np.dot(lhs, rhs)

    # A float32 relative tolerance is supplied only when running on TPU.
    if jtu.device_under_test() == "tpu":
        rtol = {np.float32: 1e-2}
    else:
        rtol = None
    self.assertAllClose(got, want, check_dtypes=False, rtol=rtol)

    traced = make_jaxpr(mat_dot)(lhs, rhs)
    self.assertEqual(len(traced.jaxpr.eqns), 1)
def test_partial_eval_lower(self):
    """The sub-jaxpr of a jitted call with a closed-over argument has one equation."""
    # this is a simplified model of a bug that arose when we first used @jit in
    # a jvp rule. it's in this file because we want to use make_jaxpr.
    @api.jit
    def f(a, b, c):
        return lax.select(lax.broadcast(a, (2, )), b, c)

    mask = onp.ones((3, 3), dtype=onp.bool_)
    on_true = onp.ones((2, 3, 3))
    on_false = onp.ones((2, 3, 3))

    # Only `b` and `c` are traced arguments; the mask is closed over.
    traced = api.make_jaxpr(lambda b, c: f(mask, b, c))(on_true, on_false)

    # Take the first sub-jaxpr of the first equation that binds any.
    eqns_with_subjaxprs = (e for e in traced.eqns if e.bound_subjaxprs)
    subjaxpr = next(e.bound_subjaxprs[0][0] for e in eqns_with_subjaxprs)
    self.assertEqual(len(subjaxpr.eqns), 1)
def test_jit_sequence1(self):
    """Two chained ``hcb.id_print`` calls under ``api.jit`` print in order."""
    def func(x):
        x1 = hcb.id_print(x, where="1", output_stream=testing_stream)
        return hcb.id_print(x1 + 1, where="2", output_stream=testing_stream)

    test_name = self._testMethodName

    # Log the jaxpr and the HLO text of the function for debugging.
    traced = api.make_jaxpr(func)(1)
    logging.info("%s: %s", test_name, traced)
    hlo_text = api.xla_computation(func)(1).as_hlo_text()
    logging.info("%s: %s", self._testMethodName, hlo_text)

    with hcb.outfeed_receiver(receiver_name=self._testMethodName):
        self.assertEqual(2, api.jit(func)(1))

    golden_output = """ where: 1 1 where: 2 2"""
    assertMultiLineStrippedEqual(self, golden_output, testing_stream.output)
    # Clear the captured output so it does not leak into other tests.
    testing_stream.reset()
def test_with_tuple_results(self):
    """``hcb.id_print`` of a tuple traces to one ``id_tap`` and prints both values."""
    def func2(x):
        x1, y1 = hcb.id_print((x * 2., x * 3.), output_stream=testing_stream)
        return x1 + y1

    golden_jaxpr = """ { lambda ; a. let b = mul a 2.00 c = mul a 3.00 d e = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*]) func=_print ] b c f = add d e in (f,) }"""
    rendered = str(api.make_jaxpr(func2)(3.))
    assertMultiLineStrippedEqual(self, golden_jaxpr, rendered)

    with hcb.outfeed_receiver():
        self.assertEqual(3. * (2. + 3.), func2(3.))

    golden_output = """ [ 6.00 9.00 ]"""
    assertMultiLineStrippedEqual(self, golden_output, testing_stream.output)
    # Clear the captured output so it does not leak into other tests.
    testing_stream.reset()
def test_jit_constant(self):
    """``hcb.id_print`` of a constant under ``api.jit`` prints it and returns ``x``."""
    def func(x):
        return hcb.id_print(42, result=x, output_stream=testing_stream)

    golden_jaxpr = """ { lambda ; a. let b = xla_call[ backend=None call_jaxpr={ lambda ; a. let b c = id_tap[ arg_treedef=* func=_print nr_untapped=1 ] 42 a in (c,) } device=None name=func ] a in (b,) }"""
    rendered = str(api.make_jaxpr(api.jit(func))(5))
    assertMultiLineStrippedEqual(self, golden_jaxpr, rendered)

    # Tracing the jitted function must not have printed anything.
    self.assertEqual("", testing_stream.output)

    with hcb.outfeed_receiver():
        self.assertAllClose(5, api.jit(func)(5), check_dtypes=True)

    golden_output = """ 42"""
    assertMultiLineStrippedEqual(self, golden_output, testing_stream.output)
    # Clear the captured output so it does not leak into other tests.
    testing_stream.reset()
def test_vmap_not_batched(self):
    """``hcb.id_print`` under ``api.vmap`` with one mapped and one unmapped value."""
    x = 3.

    def func(y):
        # x is not mapped, y is mapped
        _, y = hcb.id_print((x, y), output_stream=testing_stream)
        return x + y

    vmap_func = api.vmap(func)
    vargs = jnp.array([jnp.float32(4.), jnp.float32(5.)])

    golden_jaxpr = """ { lambda ; a. let b c = id_tap[ arg_treedef=PyTreeDef(tuple, [*,*]) func=_print transforms=(('batch', (None, 0)),) ] 3.00 a d = add c 3.00 in (d,) }"""
    rendered = str(api.make_jaxpr(vmap_func)(vargs))
    assertMultiLineStrippedEqual(self, golden_jaxpr, rendered)

    with hcb.outfeed_receiver():
        # Only the printed side effect is checked; the value is not used.
        vmap_func(vargs)

    golden_output = """ transforms: ({'name': 'batch', 'batch_dims': (None, 0)},) [ 3.00 [4.00 5.00] ]"""
    assertMultiLineStrippedEqual(self, golden_output, testing_stream.output)
    # Clear the captured output so it does not leak into other tests.
    testing_stream.reset()
def test_cond(self):
    """The printed jaxpr of a ``lax.cond`` with closed-over ``x`` matches the golden text."""
    def f(x):
        # Built in the same order as the inline argument list: predicate,
        # true operand, then false operand.
        pred = x >= 0.
        true_operand = x + 1.
        false_operand = x + 2.
        return lax.cond(pred,
                        true_operand, lambda xt: xt + x,
                        false_operand, lambda xf: xf - x)

    rendered = str(api.make_jaxpr(f)(3.))
    golden = """ { lambda ; ; a. let b = ge a 0.0 c = add a 1.0 d = add a 2.0 e = cond[ false_jaxpr={ lambda ; ; b a. let c = sub a b in [c] } false_nconsts=1 true_jaxpr={ lambda ; ; b a. let c = add a b in [c] } true_nconsts=1 ] b a c a d in [e] } """
    self.assertMultiLineStrippedEqual(rendered, golden)
def test_mask(self):
    """Check ``hcb.id_print`` under ``api.mask`` (currently skipped).

    The test raises ``SkipTest`` before doing anything, so everything after
    the ``raise`` is unreachable and is kept only for when the skip is
    removed.
    """
    # TODO(necula)
    raise SkipTest("masking has regressed")
    @partial(api.mask, in_shapes=['n'], out_shape='')
    def padded_sum(x):
        return jnp.sum(hcb.id_print(x, what="x", output_stream=testing_stream))

    # A pair: the positional inputs and the dict binding the shape variable n.
    args = [jnp.arange(4)], dict(n=np.int64(2))
    # Golden jaxpr of the masked function.
    assertMultiLineStrippedEqual(self, """ { lambda c f ; a b. let d = lt c b e = id_tap[ func=_print logical_shapes=[(Traced<ShapedArray(int32[]):JaxprTrace(level=0/0)>,)] transforms=('mask',) what=x ] a g = select d e f h = reduce_sum[ axes=(0,) ] g in (h,) }""", str(api.make_jaxpr(padded_sum)(*args)))

    _ = padded_sum(*args)
    # NOTE(review): this call uses the method form with the golden text as
    # the first argument, unlike the free-function call above — confirm the
    # intended argument order.
    self.assertMultiLineStrippedEqual(""" logical_shapes: [(2,)] transforms: ('mask',) what: x [0 1 2 3] """, testing_stream.output)
    # Clear the captured output so it does not leak into other tests.
    testing_stream.reset()
def test_grad_double(self):
    """Check a second-order ``api.grad`` of a function containing ``hcb.id_print``.

    Tracing the double gradient yields a jaxpr with no equations, and the
    tracing itself already produces printed output; the gradient is then
    evaluated and both the value and the full printed output are compared
    against golden text.
    """
    def func(x):
        y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
        return x * (y * 3.)

    grad_func = api.grad(api.grad(func))
    with hcb.outfeed_receiver():
        # Golden jaxpr: no equations, just the literal result.
        assertMultiLineStrippedEqual(
            self, """ { lambda ; a. let in (12.00,) }""", str(api.make_jaxpr(grad_func)(5.)))
        # Just making the Jaxpr invokes the id_print twice
        assertMultiLineStrippedEqual(
            self, """ transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 3.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 2.00""", testing_stream.output)
        # Discard the output produced by tracing before the real run.
        testing_stream.reset()
        res_grad = grad_func(jnp.float32(5.))

    self.assertAllClose(12., res_grad, check_dtypes=False)
    # Golden text expected in the test stream after evaluating the gradient.
    assertMultiLineStrippedEqual(
        self, """ what: x * 2 10.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 15.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}, {'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 2.00 transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2 3.00""", testing_stream.output)
    # Clear the captured output so it does not leak into other tests.
    testing_stream.reset()
def test_eval(self):
    """Plain evaluation of ``fun1``: golden jaxpr, numeric result, printed output."""
    golden_jaxpr = """ { lambda ; a. let b = mul a 2.00 c = id_tap[ arg_treedef=* func=_print what=a * 2 ] b d = mul c 3.00 e f = id_tap[ arg_treedef=* func=_print nr_untapped=1 what=y * 3 ] d c g = integer_pow[ y=2 ] f in (g,) }"""
    rendered = str(api.make_jaxpr(fun1)(5.))
    assertMultiLineStrippedEqual(self, golden_jaxpr, rendered)

    # Tracing alone must not have printed anything.
    self.assertEqual("", testing_stream.output)

    with hcb.outfeed_receiver():
        self.assertAllClose((5. * 2.) ** 2, fun1(5.), check_dtypes=True)

    golden_output = """ what: a * 2 10.00 what: y * 3 30.00"""
    assertMultiLineStrippedEqual(self, golden_output, testing_stream.output)
    # Clear the captured output so it does not leak into other tests.
    testing_stream.reset()