def test_while_batched(self, with_function=True):
  """A vmapped while_loop whose carry is an (index, accumulator) pair.

  `product(x, y)` computes ``x * y`` by repeated addition. It is vmapped
  twice: the inner vmap batches ``x`` (used by the loop body) and the outer
  vmap batches ``y`` (used by the loop predicate).
  """
  # NOTE(review): another `test_while_batched` is defined later in this chunk;
  # if both are in the same class, only the later definition runs.
  def product(x, y):
    # Equivalent to "x * y" implemented as:
    #   res = 0.
    #   for (i = 0; i < y; i++)
    #     res += x
    return lax.while_loop(
        lambda idx_carry: idx_carry[0] < y,
        lambda idx_carry: (idx_carry[0] + 1, idx_carry[1] + x),
        (0, 0.))

  xs = np.arange(4, dtype=np.int32)
  ys = np.arange(5, dtype=np.int32)

  def product_xs_y(xs, y):
    # Batch over xs (axis 0); the scalar y is shared.
    return jax.vmap(product, in_axes=(0, None))(xs, y)

  def product_xs_ys(xs, ys):
    # Batch over ys (axis 0), so result[j, i] == ys[j] * xs[i].
    return jax.vmap(product_xs_y, in_axes=(None, 0))(xs, ys)

  with jax_to_tf.enable_jit():
    # Check the JAX result against the closed form first, so the JAX-vs-TF
    # comparison cannot pass with two equally wrong results.
    np.testing.assert_allclose(product_xs_ys(xs, ys), np.outer(ys, xs))
    self.ConvertAndCompare(product_xs_ys, xs, ys, with_function=with_function)
def test_while(self, with_function=False):
  """A while loop whose cond and body both close over numpy constants."""
  # NOTE(review): another `test_while` is defined later in this chunk; if both
  # are in the same class, that later definition shadows this one.
  # Constants captured by the loop's cond and body functions.
  cond_const = np.ones(3, dtype=np.float32)
  body_const1 = np.full_like(cond_const, 1.)
  body_const2 = np.full_like(cond_const, 2.)

  def func(x):
    """Starting from c = x, adds [1, 1, 1] + [2, 2, 2] to c three times.

    The loop bound is sum(cond_const) == 3. Capturing constants inside the
    functionals covers some cases in the representation of the lax.while
    primitive.
    """
    def cond(idx_carry):
      idx, _ = idx_carry
      return idx < jnp.sum(lax.tie_in(idx, cond_const))  # captures cond_const

    def body(idx_carry):
      idx, acc = idx_carry
      return idx + 1, acc + body_const1 + body_const2

    return lax.while_loop(cond, body, (0, x))

  with jax_to_tf.enable_jit():
    self.ConvertAndCompare(func, cond_const, with_function=with_function)
def test_cond_multiple_results(self, with_function=False):
  """lax.cond whose branches each return a 2-tuple."""
  def f_jax(pred, x):
    def on_true(t):
      return t + 1., 1.

    def on_false(f):
      return f + 2., 2.

    return lax.cond(pred, on_true, on_false, x)

  with jax_to_tf.enable_jit():
    # Exercise both branches: True first, then False.
    for pred in (True, False):
      self.ConvertAndCompare(f_jax, pred, 1., with_function=with_function)
def test_cond(self, with_function=False):
  """lax.cond choosing between an increment and the identity."""
  # NOTE(review): another `test_cond` is defined later in this chunk; if both
  # are in the same class, that later definition shadows this one.
  def increment(t):
    return t + 1.

  def identity(f):
    return f

  def f_jax(pred, x):
    return lax.cond(pred, increment, identity, x)

  with jax_to_tf.enable_jit():
    # Exercise both branches: True first, then False.
    for pred in (True, False):
      self.ConvertAndCompare(f_jax, pred, 1., with_function=with_function)
def test_while_single_carry(self, with_function=False):
  """A while loop whose carry is a single scalar."""
  # NOTE(review): another `test_while_single_carry` is defined later in this
  # chunk; if both are in the same class, that later definition shadows this.
  def func(x):
    # Equivalent to: for (i = x; i < 4; i++); -- yields the final i.
    def keep_going(i):
      return i < 4

    def step(i):
      return i + 1

    return lax.while_loop(keep_going, step, x)

  with jax_to_tf.enable_jit():
    self.ConvertAndCompare(func, 0, with_function=with_function)
def test_cond(self, with_function=False):
  """lax.cond choosing between an increment and the identity.

  Both branches are checked via ConvertAndCompare: pred=True, then pred=False.
  """
  # NOTE(review): `test_cond` is also defined earlier in this chunk; if both
  # are in the same class, this later definition is the one that runs.
  def f_jax(pred, x):
    return lax.cond(pred, lambda t: t + 1., lambda f: f, x)

  with jax_to_tf.enable_jit():
    self.ConvertAndCompare(f_jax, True, 1., with_function=with_function)
    self.ConvertAndCompare(f_jax, False, 1., with_function=with_function)
def test_scan(self, with_function=False):
  """lax.scan accumulating sum(xs * ys), with a captured constant per step."""
  def f_jax(xs, ys):
    # Closed-over numpy constant, emitted as each step's output, to test
    # constant capture.
    body_const = np.ones((2, ), dtype=np.float32)

    def step(acc, pair):
      x, y = pair
      return acc + x * y, body_const

    return lax.scan(step, 0., (xs, ys))

  operand = np.arange(10, dtype=np.float32)
  with jax_to_tf.enable_jit():
    self.ConvertAndCompare(f_jax, operand, operand,
                           with_function=with_function)
def test_while_single_carry(self, with_function=False):
  """A while loop whose carry is a single scalar."""
  # NOTE(review): `test_while_single_carry` is also defined earlier in this
  # chunk; if both are in the same class, this later definition is the one
  # that runs.
  def func(x):
    # Equivalent to:
    #   for (i = x; i < 4; i++);
    return lax.while_loop(lambda c: c < 4, lambda c: c + 1, x)

  with jax_to_tf.enable_jit():
    self.ConvertAndCompare(func, 0, with_function=with_function)
def test_while_batched(self, with_function=True):
  """A vmapped while_loop whose carry is an (index, accumulator) pair.

  `product(x, y)` computes ``x * y`` by repeated addition. It is vmapped
  twice: the inner vmap batches ``x`` (used by the loop body) and the outer
  vmap batches ``y`` (used by the loop predicate).
  """
  # NOTE(review): `test_while_batched` is also defined earlier in this chunk;
  # if both are in the same class, this later definition is the one that runs.
  def product(x, y):
    # Equivalent to "x * y" implemented as:
    #   res = 0.
    #   for (i = 0; i < y; i++)
    #     res += x
    return lax.while_loop(
        lambda idx_carry: idx_carry[0] < y,
        lambda idx_carry: (idx_carry[0] + 1, idx_carry[1] + x),
        (0, 0.))

  xs = np.arange(4, dtype=np.int32)
  ys = np.arange(5, dtype=np.int32)

  def product_xs_y(xs, y):
    # Batch over xs (axis 0); the scalar y is shared.
    return jax.vmap(product, in_axes=(0, None))(xs, y)

  def product_xs_ys(xs, ys):
    # Batch over ys (axis 0), so result[j, i] == ys[j] * xs[i].
    return jax.vmap(product_xs_y, in_axes=(None, 0))(xs, ys)

  with jax_to_tf.enable_jit():
    # Check the JAX result against the closed form first, so the JAX-vs-TF
    # comparison cannot pass with two equally wrong results.
    np.testing.assert_allclose(product_xs_ys(xs, ys), np.outer(ys, xs))
    self.ConvertAndCompare(product_xs_ys, xs, ys, with_function=with_function)
def test_while(self, with_function=False):
  """A while loop whose cond and body both close over numpy constants."""
  # NOTE(review): `test_while` is also defined earlier in this chunk; if both
  # are in the same class, this later definition is the one that runs.
  # Constants captured by the loop's cond and body functions.
  cond_const = np.ones(3, dtype=np.float32)
  body_const1 = np.full_like(cond_const, 1.)
  body_const2 = np.full_like(cond_const, 2.)

  def func(x):
    # Equivalent to:
    #   c = x  (== [1, 1, 1] in this test)
    #   for (i = 0; i < 3; i++)
    #     c += [1, 1, 1] + [2, 2, 2]
    #
    # The function is set up so that it captures constants in the body of
    # the functionals. This covers some cases in the representation of the
    # lax.while primitive.
    def cond(idx_carry):
      i, _ = idx_carry
      return i < jnp.sum(lax.tie_in(i, cond_const))  # Capture cond_const

    def body(idx_carry):
      i, c = idx_carry
      return (i + 1, c + body_const1 + body_const2)

    return lax.while_loop(cond, body, (0, x))

  with jax_to_tf.enable_jit():
    # cond_const doubles as the initial carry value.
    self.ConvertAndCompare(func, cond_const, with_function=with_function)