示例#1
0
文件: mlir.py 项目: rsepassi/jax
def _source_info_to_location(
        source_info: source_info_util.SourceInfo) -> ir.Location:
    """Convert JAX source information into an MLIR location.

    Args:
        source_info: Source information from which
            ``source_info_util.user_frame`` selects the frame to report.

    Returns:
        ``ir.Location.unknown()`` when no user frame is found; otherwise an
        ``ir.Location.file`` built from the frame's canonical source file
        and its line number.
    """
    user_frame = source_info_util.user_frame(source_info)
    if user_frame is not None:
        canonical_file = xla._get_canonical_source_file(user_frame)
        # Only a line number is read from the frame; the third argument
        # (presumably the column -- confirm against the MLIR bindings) is
        # always 1.
        return ir.Location.file(canonical_file, user_frame.line_num, 1)
    return ir.Location.unknown()
示例#2
0
    def test_op_metadata_named(self):
        """Check op metadata for ops traced through a ``jax.named_call``.

        The test is currently skipped unconditionally by the ``skipTest``
        call below, so nothing after that call runs.

        The expected ``source_line`` values are offsets from the line of the
        ``user_frame = ...`` assignment. They depend on the exact layout of
        the code that follows it: adding or removing any line (blank lines
        included) between that assignment and a traced op requires updating
        the matching offset.
        """
        self.skipTest("include_xla_op_metadata not yet enabled")
        # Calling a jax.named_call
        # The user_frame is used to compute line numbers for ops in the test.
        user_frame = source_info_util.user_frame(source_info_util.current())

        def f_callee(x):
            return jnp.cos(x)  # user_frame line + 3

        def f_caller(x):
            y = jnp.tanh(x)  # user_frame line + 6
            z = jax.named_call(f_callee, name="callee")(y)
            return jnp.sin(z)  # user_frame line + 8

        x = np.ones((2, 3), np.float32)

        self.CheckOpMetadata(f_caller, x, [
            tf_test_util.OpMetadataGraph(tf_type="Tanh",
                                         source_file=__file__,
                                         source_line=user_frame.line_num + 6,
                                         op_name="jax2tf(f_caller)/tanh",
                                         op_type="tanh"),
            tf_test_util.OpMetadataGraph(
                tf_type="Cos",
                source_file=__file__,
                source_line=user_frame.line_num + 3,
                op_name="jax2tf(f_caller)/named(callee)/cos",
                op_type="cos"),
            tf_test_util.OpMetadataGraph(tf_type="Sin",
                                         source_file=__file__,
                                         source_line=user_frame.line_num + 8,
                                         op_name="jax2tf(f_caller)/sin",
                                         op_type="sin"),
        ])
示例#3
0
文件: xla.py 项目: wayfeng/jax
def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: str = "",
                     ) -> xc.OpMetadata:
  """Build the ``xc.OpMetadata`` describing one primitive application.

  As a side effect, the traceback of ``source_info`` is stored in the
  module-level ``tracebacks`` mapping under the generated op name.

  Args:
    primitive: The primitive; its ``name`` becomes the ``op_type``.
    params: The primitive's parameters, rendered into the op name.
    source_info: Source information providing the traceback and the user
      frame used for ``source_file`` / ``source_line``.
    name_stack: Prefix prepended to the compact equation text.

  Returns:
    An ``xc.OpMetadata`` whose ``source_file`` and ``source_line`` are
    ``None`` when no user frame is available.
  """
  prefix_doc = pp.text(name_stack)
  eqn_doc = pp_eqn_compact(primitive.name, params, JaxprPpContext())
  eqn_str = str(prefix_doc + eqn_doc)
  # NOTE(review): ``source_info.traceback`` is read unconditionally here, so
  # the falsy-``source_info`` guard below cannot protect against ``None``.
  # Confirm whether ``None`` is ever passed by callers.
  tracebacks[eqn_str] = source_info.traceback
  if source_info:
    frame = source_info_util.user_frame(source_info)
  else:
    frame = None
  if frame:
    source_file = _get_canonical_source_file(frame)
    source_line = frame.line_num
  else:
    source_file = None
    source_line = None
  return xc.OpMetadata(
        op_type=primitive.name,
        op_name=eqn_str,
        source_file=source_file,
        source_line=source_line)
示例#4
0
文件: mlir.py 项目: GJBoth/jax
def _source_info_to_location(
    primitive: core.Primitive, params: Dict,
    source_info: source_info_util.SourceInfo,
    name_stack: str = "") -> ir.Location:
  """Build a named MLIR location for one primitive application.

  Args:
    primitive: The primitive whose ``name`` is rendered into the label.
    params: The primitive's parameters, rendered into the label.
    source_info: Source information from which the user frame is taken.
    name_stack: Prefix prepended to the compact equation text.

  Returns:
    An ``ir.Location.name`` labelled with the equation string. Its child is
    a file location for the user frame, or ``ir.Location.unknown()`` when
    no user frame is found.
  """
  eqn_str = name_stack + core.str_eqn_compact(primitive.name, params)
  user_frame = source_info_util.user_frame(source_info)
  if user_frame is not None:
    canonical_file = xla._get_canonical_source_file(user_frame)
    # Only a line number is read from the frame; the third argument is
    # always 1.
    child_loc = ir.Location.file(canonical_file, user_frame.line_num, 1)
  else:
    child_loc = ir.Location.unknown()
  # TODO(phawkins): also include primitive.name as the operator type.
  return ir.Location.name(eqn_str, childLoc=child_loc)
示例#5
0
    def test_op_metadata_simple(self):
        """Check op metadata for a function containing a single ``sin`` op.

        The test is currently skipped unconditionally by the ``skipTest``
        call below, so nothing after that call runs.

        The expected ``source_line`` is an offset from the line of the
        ``user_frame = ...`` assignment. It depends on the exact layout of
        the code that follows it: adding or removing any line (blank lines
        included) between that assignment and the traced op requires
        updating the offset.
        """
        self.skipTest("include_xla_op_metadata not yet enabled")
        # A simple example
        # The user_frame is used to compute line numbers for ops in the test.
        user_frame = source_info_util.user_frame(source_info_util.current())

        def f_simple(x):
            return jnp.sin(x)  # user_frame line + 3

        x = np.ones((2, 3), np.float32)
        self.CheckOpMetadata(f_simple, x, [
            tf_test_util.OpMetadataGraph(tf_type="Sin",
                                         source_file=__file__,
                                         source_line=user_frame.line_num + 3,
                                         op_name="jax2tf(f_simple)/sin",
                                         op_type="sin")
        ])
示例#6
0
文件: xla.py 项目: John1Tang/jax
def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: Union[str, source_info_util.NameStack] = "",
                     ) -> xc.OpMetadata:
  """Build the ``xc.OpMetadata`` describing one primitive application.

  As a side effect, the traceback of ``source_info`` is stored in the
  module-level ``tracebacks`` mapping under the generated op name.

  Args:
    primitive: The primitive; its ``name`` becomes the ``op_type``.
    params: The primitive's parameters, rendered into the op name.
    source_info: Source information providing the traceback, the user frame
      and, when ``config.jax_experimental_name_stack`` is set, the name
      stack used as the op-name prefix.
    name_stack: Op-name prefix used when the experimental name stack is
      disabled; asserted to be a ``str`` in that case.

  Returns:
    An ``xc.OpMetadata`` whose ``source_file`` and ``source_line`` are
    ``None`` when no user frame is available.
  """
  if config.jax_experimental_name_stack:
    # The name stack carried by source_info is used; the argument is ignored.
    prefix = str(source_info.name_stack) + '/'
  else:
    assert isinstance(name_stack, str)
    prefix = name_stack
  eqn_str = prefix + str_eqn_compact(primitive.name, params)
  tracebacks[eqn_str] = source_info.traceback
  frame = source_info_util.user_frame(source_info)
  if frame:
    source_file = _get_canonical_source_file(frame)
    source_line = frame.line_num
  else:
    source_file = None
    source_line = None
  return xc.OpMetadata(
        op_type=primitive.name,
        op_name=eqn_str,
        source_file=source_file,
        source_line=source_line)
示例#7
0
    def test_op_metadata_batched_while(self):
        """Check op metadata for ops inside a ``jax.vmap``-ed while loop.

        The test is currently skipped unconditionally by the ``skipTest``
        call below, so nothing after that call runs.

        The expected ``source_line`` values are offsets from the line of the
        ``user_frame = ...`` assignment. They depend on the exact layout of
        the code that follows it: adding or removing any line (blank lines
        included) between that assignment and a traced op requires updating
        the matching offset.
        """
        self.skipTest("include_xla_op_metadata not yet enabled")
        # An example with a while loop under jax.vmap
        # The user_frame is used to compute line numbers for ops in the test.
        user_frame = source_info_util.user_frame(source_info_util.current())

        @jax.vmap
        def f_while(x):
            def body_fun(carry):
                # The "sin" below is on user_frame line + 5.
                new_carry = jnp.sin(carry)  # We look for "sin" in the graph
                return new_carry

            _, carry = lax.while_loop(
                lambda carry: jnp.all(carry <= x
                                      ),  # We look for "le" in the graph
                body_fun,
                x)
            return carry

        shape = (3, 2)
        x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)

        # Print the HLO text of the first compiled module for inspection.
        jax_comp = jax.xla_computation(f_while)(x)
        backend = jax.lib.xla_bridge.get_backend()
        modules = backend.compile(jax_comp).hlo_modules()
        jax_opt_hlo = modules[0].to_string()
        print(f"JAX OPT HLO = {jax_opt_hlo}")

        self.CheckOpMetadata(f_while, x, [
            tf_test_util.OpMetadataGraph(
                tf_type="Sin",
                source_file=__file__,
                source_line=user_frame.line_num + 6,
                op_name="jax2tf(f_while)/while/body/sin",
                op_type="sin"),
            tf_test_util.OpMetadataGraph(
                tf_type="LessEqual",
                source_file=__file__,
                source_line=user_frame.line_num + 10,
                op_name="jax2tf(f_while)/while/body_pred/le",
                op_type="le"),
        ])
示例#8
0
  def test_op_metadata_while_and_cond(self):
    """Check op metadata for ops inside a while loop whose body has a cond.

    The test is currently skipped unconditionally by the ``skipTest`` call
    below, so nothing after that call runs.

    The expected ``source_line`` values are offsets from the line of the
    ``user_frame = ...`` assignment, so no line may be added or removed
    between that assignment and the traced ops without updating them.
    """
    self.skipTest("include_xla_op_metadata not yet enabled")
    # An example with while and cond
    # The user_frame is used to compute line numbers for ops in the test.
    user_frame = source_info_util.user_frame(source_info_util.current())
    def f_while_cond(x):
      def body_fun(i_acc):
        i, acc = i_acc
        return (i + 1,
                (jnp.cos(acc) +  # user_frame line + 5
                 lax.cond(jnp.mod(i, 2) == 0,  # user_frame line + 6
                          lambda acc: jnp.sin(acc),  # user_frame line + 7
                          lambda acc: acc,
                          acc)))

      # Loop while the counter (first element of the carry) is at most 5.
      _, acc = lax.while_loop(
          lambda i_acc: i_acc[0] <= 5,
          body_fun, (0, x))
      return acc

    x = np.ones((2, 3), np.float32)
    self.CheckOpMetadata(
        f_while_cond, x,
        [tf_test_util.OpMetadataGraph(tf_type="Cos",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 5,
                                      op_name="jax2tf(f_while_cond)/while/body/cos",
                                      op_type="cos"),
         tf_test_util.OpMetadataGraph(tf_type="Sin",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 7,
                                      op_name="jax2tf(f_while_cond)/while/body/branch_1_fun/sin",
                                      op_type="sin"),
         # jnp.mod is expected to appear as a "rem" op (TF type "FloorMod").
         tf_test_util.OpMetadataGraph(tf_type="FloorMod",
                                      source_file=__file__,
                                      source_line=user_frame.line_num + 6,
                                      op_name="jax2tf(f_while_cond)/while/body/rem",
                                      op_type="rem"),
         ]
    )