def test_op_metadata_named(self): self.skipTest("include_xla_op_metadata not yet enabled") # Calling a jax.named_call # The user_frame is used to compute line numbers for ops in the test. user_frame = source_info_util.user_frame(source_info_util.current()) def f_callee(x): return jnp.cos(x) def f_caller(x): y = jnp.tanh(x) z = jax.named_call(f_callee, name="callee")(y) return jnp.sin(z) x = np.ones((2, 3), np.float32) self.CheckOpMetadata(f_caller, x, [ tf_test_util.OpMetadataGraph(tf_type="Tanh", source_file=__file__, source_line=user_frame.line_num + 4, op_name="jax2tf(f_caller)/tanh", op_type="tanh"), tf_test_util.OpMetadataGraph( tf_type="Cos", source_file=__file__, source_line=user_frame.line_num + 2, op_name="jax2tf(f_caller)/named(callee)/cos", op_type="cos"), tf_test_util.OpMetadataGraph(tf_type="Sin", source_file=__file__, source_line=user_frame.line_num + 6, op_name="jax2tf(f_caller)/sin", op_type="sin"), ])
def test_op_metadata_batched_while(self): self.skipTest("include_xla_op_metadata not yet enabled") # An example with while and cond # The user_frame is used to compute line numbers for ops in the test. user_frame = source_info_util.user_frame(source_info_util.current()) @jax.vmap def f_while(x): def body_fun(carry): new_carry = jnp.sin(carry) # We look for "sin" in the graph return new_carry _, carry = lax.while_loop( lambda carry: jnp.all(carry <= x ), # We look for "le" in the graph body_fun, x) return carry shape = (3, 2) x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) jax_comp = jax.xla_computation(f_while)(x) backend = jax.lib.xla_bridge.get_backend() modules = backend.compile(jax_comp).hlo_modules() jax_opt_hlo = modules[0].to_string() print(f"JAX OPT HLO = {jax_opt_hlo}") self.CheckOpMetadata(f_while, x, [ tf_test_util.OpMetadataGraph( tf_type="Sin", source_file=__file__, source_line=user_frame.line_num + 4, op_name="jax2tf(f_while)/while/body/sin", op_type="sin"), tf_test_util.OpMetadataGraph( tf_type="LessEqual", source_file=__file__, source_line=user_frame.line_num + 8, op_name="jax2tf(f_while)/while/body_pred/le", op_type="le"), ])
def test_op_metadata_while_and_cond(self): self.skipTest("include_xla_op_metadata not yet enabled") # An example with while and cond # The user_frame is used to compute line numbers for ops in the test. user_frame = source_info_util.user_frame(source_info_util.current()) def f_while_cond(x): def body_fun(i_acc): i, acc = i_acc return (i + 1, (jnp.cos(acc) + lax.cond(jnp.mod(i, 2) == 0, lambda acc: jnp.sin(acc), lambda acc: acc, acc))) _, acc = lax.while_loop( lambda i_acc: i_acc[0] <= 5, body_fun, (0, x)) return acc x = np.ones((2, 3), np.float32) self.CheckOpMetadata( f_while_cond, x, [tf_test_util.OpMetadataGraph(tf_type="Cos", source_file=__file__, source_line=user_frame.line_num + 5, op_name="jax2tf(f_while_cond)/while/body/cos", op_type="cos"), tf_test_util.OpMetadataGraph(tf_type="Sin", source_file=__file__, source_line=user_frame.line_num + 7, op_name="jax2tf(f_while_cond)/while/body/branch_1_fun/sin", op_type="sin"), tf_test_util.OpMetadataGraph(tf_type="FloorMod", source_file=__file__, source_line=user_frame.line_num + 6, op_name="jax2tf(f_while_cond)/while/body/rem", op_type="rem"), ] )
def test_op_metadata_simple(self): self.skipTest("include_xla_op_metadata not yet enabled") # A simple example # The user_frame is used to compute line numbers for ops in the test. user_frame = source_info_util.user_frame(source_info_util.current()) def f_simple(x): return jnp.sin(x) x = np.ones((2, 3), np.float32) self.CheckOpMetadata(f_simple, x, [ tf_test_util.OpMetadataGraph(tf_type="Sin", source_file=__file__, source_line=user_frame.line_num + 2, op_name="jax2tf(f_simple)/sin", op_type="sin") ])