def _source_info_to_location(
    source_info: source_info_util.SourceInfo) -> ir.Location:
  """Converts `source_info` into an `ir.Location`.

  Args:
    source_info: source information whose user frame is looked up via
      `source_info_util.user_frame`.

  Returns:
    `ir.Location.unknown()` when no user frame is found; otherwise a file
    location built from the frame's canonical source file and line number.
    The column is always reported as 1.
  """
  user_frame = source_info_util.user_frame(source_info)
  if user_frame is not None:
    source_file = xla._get_canonical_source_file(user_frame)
    return ir.Location.file(source_file, user_frame.line_num, 1)
  return ir.Location.unknown()
def test_op_metadata_named(self):
  self.skipTest("include_xla_op_metadata not yet enabled")
  # Op metadata for ops traced inside and around a `jax.named_call`.
  #
  # `user_frame` anchors the expected line numbers: every `source_line` below
  # is an offset from the `user_frame = ...` statement, so the layout between
  # that statement and the traced ops must not change.
  user_frame = source_info_util.user_frame(source_info_util.current())
  def f_callee(x):
    return jnp.cos(x)
  def f_caller(x):
    y = jnp.tanh(x)
    z = jax.named_call(f_callee, name="callee")(y)
    return jnp.sin(z)

  anchor_line = user_frame.line_num
  operand = np.ones((2, 3), np.float32)
  expected = [
      tf_test_util.OpMetadataGraph(
          tf_type="Tanh",
          source_file=__file__,
          source_line=anchor_line + 4,
          op_name="jax2tf(f_caller)/tanh",
          op_type="tanh"),
      tf_test_util.OpMetadataGraph(
          tf_type="Cos",
          source_file=__file__,
          source_line=anchor_line + 2,
          op_name="jax2tf(f_caller)/named(callee)/cos",
          op_type="cos"),
      tf_test_util.OpMetadataGraph(
          tf_type="Sin",
          source_file=__file__,
          source_line=anchor_line + 6,
          op_name="jax2tf(f_caller)/sin",
          op_type="sin"),
  ]
  self.CheckOpMetadata(f_caller, operand, expected)
def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: str = "",
                     ) -> xc.OpMetadata:
  """Builds the `xc.OpMetadata` describing one primitive application.

  The op name is the pretty-printed `name_stack` followed by the compact
  pretty-printed equation for `primitive` and `params`. As a side effect the
  traceback of `source_info` is stored in `tracebacks` under that op name.

  Args:
    primitive: the primitive; its `name` becomes the op type.
    params: the primitive's parameters, used when printing the equation.
    source_info: source information providing the traceback and user frame.
    name_stack: string prefix for the op name.

  Returns:
    An `xc.OpMetadata` whose source file and line are taken from the user
    frame, or are both `None` when there is no user frame.
  """
  eqn_doc = pp.text(name_stack) + pp_eqn_compact(
      primitive.name, params, JaxprPpContext())
  op_name = str(eqn_doc)
  tracebacks[op_name] = source_info.traceback

  # NOTE(review): `source_info` is tested for truthiness here even though its
  # `.traceback` was already read above -- confirm whether a falsy
  # `source_info` can actually reach this function.
  if source_info:
    user_frame = source_info_util.user_frame(source_info)
  else:
    user_frame = None

  if user_frame:
    source_file = _get_canonical_source_file(user_frame)
    source_line = user_frame.line_num
  else:
    source_file = None
    source_line = None

  return xc.OpMetadata(
      op_type=primitive.name,
      op_name=op_name,
      source_file=source_file,
      source_line=source_line)
def _source_info_to_location(
    primitive: core.Primitive,
    params: Dict,
    source_info: source_info_util.SourceInfo,
    name_stack: str = "") -> ir.Location:
  """Builds a named `ir.Location` for one primitive application.

  Args:
    primitive: the primitive; its `name` is used when printing the equation.
    params: the primitive's parameters, used when printing the equation.
    source_info: source information whose user frame gives the file and line.
    name_stack: string prefix placed before the compact equation string.

  Returns:
    A name location labelled with `name_stack` plus the compact equation
    string. Its child is a file location (column fixed at 1) when a user
    frame exists, and an unknown location otherwise.
  """
  label = name_stack + core.str_eqn_compact(primitive.name, params)
  user_frame = source_info_util.user_frame(source_info)
  if user_frame is not None:
    child = ir.Location.file(
        xla._get_canonical_source_file(user_frame), user_frame.line_num, 1)
  else:
    child = ir.Location.unknown()
  # TODO(phawkins): also include primitive.name as the operator type.
  return ir.Location.name(label, childLoc=child)
def test_op_metadata_simple(self):
  self.skipTest("include_xla_op_metadata not yet enabled")
  # Op metadata for a function containing a single `sin` op.
  #
  # `user_frame` anchors the expected line number: `source_line` below is an
  # offset from the `user_frame = ...` statement, so the layout between that
  # statement and the traced op must not change.
  user_frame = source_info_util.user_frame(source_info_util.current())
  def f_simple(x):
    return jnp.sin(x)

  operand = np.ones((2, 3), np.float32)
  expected = [
      tf_test_util.OpMetadataGraph(
          tf_type="Sin",
          source_file=__file__,
          source_line=user_frame.line_num + 2,
          op_name="jax2tf(f_simple)/sin",
          op_type="sin"),
  ]
  self.CheckOpMetadata(f_simple, operand, expected)
def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: Union[str, source_info_util.NameStack] = "",
                     ) -> xc.OpMetadata:
  """Builds the `xc.OpMetadata` describing one primitive application.

  The op name is a prefix followed by the compact equation string for
  `primitive` and `params`. When `config.jax_experimental_name_stack` is set
  the prefix is `str(source_info.name_stack)` plus `'/'` and the `name_stack`
  argument is ignored; otherwise the prefix is the `name_stack` argument,
  which must then be a `str`. As a side effect the traceback of `source_info`
  is stored in `tracebacks` under the op name.

  Args:
    primitive: the primitive; its `name` becomes the op type.
    params: the primitive's parameters, used when printing the equation.
    source_info: source information providing the name stack, traceback and
      user frame.
    name_stack: op-name prefix used when the experimental name stack is off.

  Returns:
    An `xc.OpMetadata` whose source file and line are taken from the user
    frame, or are both `None` when there is no user frame.
  """
  if config.jax_experimental_name_stack:
    prefix = str(source_info.name_stack) + '/'
  else:
    assert isinstance(name_stack, str)
    prefix = name_stack
  op_name = prefix + str_eqn_compact(primitive.name, params)
  tracebacks[op_name] = source_info.traceback

  user_frame = source_info_util.user_frame(source_info)
  if user_frame:
    source_file = _get_canonical_source_file(user_frame)
    source_line = user_frame.line_num
  else:
    source_file = None
    source_line = None

  return xc.OpMetadata(
      op_type=primitive.name,
      op_name=op_name,
      source_file=source_file,
      source_line=source_line)
def test_op_metadata_batched_while(self):
  self.skipTest("include_xla_op_metadata not yet enabled")
  # Op metadata for a `lax.while_loop` wrapped in `jax.vmap`.
  #
  # `user_frame` anchors the expected line numbers: every `source_line` below
  # is an offset from the `user_frame = ...` statement, so the layout between
  # that statement and the traced ops must not change.
  user_frame = source_info_util.user_frame(source_info_util.current())
  @jax.vmap
  def f_while(x):
    def body_fun(carry):
      new_carry = jnp.sin(carry)  # We look for "sin" in the graph
      return new_carry

    _, carry = lax.while_loop(
        lambda carry: jnp.all(carry <= x),  # We look for "le" in the graph
        body_fun, x)
    return carry

  dims = (3, 2)
  operand = np.arange(np.prod(dims), dtype=np.float32).reshape(dims)

  # Print the HLO text of the first compiled module for the same computation.
  jax_comp = jax.xla_computation(f_while)(operand)
  backend = jax.lib.xla_bridge.get_backend()
  compiled_modules = backend.compile(jax_comp).hlo_modules()
  jax_opt_hlo = compiled_modules[0].to_string()
  print(f"JAX OPT HLO = {jax_opt_hlo}")

  anchor_line = user_frame.line_num
  expected = [
      tf_test_util.OpMetadataGraph(
          tf_type="Sin",
          source_file=__file__,
          source_line=anchor_line + 4,
          op_name="jax2tf(f_while)/while/body/sin",
          op_type="sin"),
      tf_test_util.OpMetadataGraph(
          tf_type="LessEqual",
          source_file=__file__,
          source_line=anchor_line + 8,
          op_name="jax2tf(f_while)/while/body_pred/le",
          op_type="le"),
  ]
  self.CheckOpMetadata(f_while, operand, expected)
def test_op_metadata_while_and_cond(self):
  self.skipTest("include_xla_op_metadata not yet enabled")
  # Op metadata for a `lax.cond` nested in the body of a `lax.while_loop`.
  #
  # `user_frame` anchors the expected line numbers: every `source_line` below
  # is an offset from the `user_frame = ...` statement, so the layout between
  # that statement and the traced ops must not change.
  user_frame = source_info_util.user_frame(source_info_util.current())
  def f_while_cond(x):
    def body_fun(i_acc):
      i, acc = i_acc
      return (i + 1,
              (jnp.cos(acc) +
               lax.cond(jnp.mod(i, 2) == 0,
                        lambda acc: jnp.sin(acc),
                        lambda acc: acc,
                        acc)))

    _, acc = lax.while_loop(
        lambda i_acc: i_acc[0] <= 5,
        body_fun, (0, x))
    return acc

  anchor_line = user_frame.line_num
  operand = np.ones((2, 3), np.float32)
  expected = [
      tf_test_util.OpMetadataGraph(
          tf_type="Cos",
          source_file=__file__,
          source_line=anchor_line + 5,
          op_name="jax2tf(f_while_cond)/while/body/cos",
          op_type="cos"),
      tf_test_util.OpMetadataGraph(
          tf_type="Sin",
          source_file=__file__,
          source_line=anchor_line + 7,
          op_name="jax2tf(f_while_cond)/while/body/branch_1_fun/sin",
          op_type="sin"),
      tf_test_util.OpMetadataGraph(
          tf_type="FloorMod",
          source_file=__file__,
          source_line=anchor_line + 6,
          op_name="jax2tf(f_while_cond)/while/body/rem",
          op_type="rem"),
  ]
  self.CheckOpMetadata(f_while_cond, operand, expected)