Example #1
0
File: xla.py Project: John1Tang/jax
def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: Union[str, source_info_util.NameStack] = "",
                     ) -> xc.OpMetadata:
  """Build the `xc.OpMetadata` describing one primitive application.

  The op name is a compact rendering of the equation (from
  `str_eqn_compact`) prefixed by a name stack. Which name stack is used
  depends on `config.jax_experimental_name_stack`:

  * flag set: the stack is read from `source_info.name_stack` and joined
    to the equation string with a '/'; the `name_stack` argument is not
    consulted.
  * flag unset: the `name_stack` argument (which must be a `str`) is
    prepended directly, with no separator inserted.

  As a side effect, `source_info.traceback` is stored in the `tracebacks`
  mapping (defined outside this function) under the resulting op name.

  Args:
    primitive: primitive whose `name` becomes the op type.
    params: equation parameters, forwarded to `str_eqn_compact`.
    source_info: supplies the traceback, the user frame and (when the
      experimental flag is set) the name stack.
    name_stack: string prefix used only when the experimental flag is unset.

  Returns:
    An `xc.OpMetadata` whose `source_file` / `source_line` are `None` when
    no user frame could be determined.
  """
  if not config.jax_experimental_name_stack:
    # Narrows the Union for the type checker; the legacy path only
    # supports plain-string name stacks.
    assert isinstance(name_stack, str)
    op_label = name_stack + str_eqn_compact(primitive.name, params)
  else:
    op_label = '/'.join([str(source_info.name_stack),
                         str_eqn_compact(primitive.name, params)])

  # Record the traceback under the op name.
  tracebacks[op_label] = source_info.traceback

  user_frame = source_info_util.user_frame(source_info)
  if user_frame:
    file_name = _get_canonical_source_file(user_frame)
    line_no = user_frame.line_num
  else:
    file_name = line_no = None

  return xc.OpMetadata(
      op_type=primitive.name,
      op_name=op_label,
      source_file=file_name,
      source_line=line_no)
Example #2
0
File: mlir.py Project: GJBoth/jax
def _source_info_to_location(
    primitive: core.Primitive, params: Dict,
    source_info: source_info_util.SourceInfo,
    name_stack: str = "") -> ir.Location:
  """Return an `ir.Location` naming the equation and, if known, its origin.

  The result is a name location whose label is `name_stack` followed by the
  compact equation string from `core.str_eqn_compact`. Its child location is
  a file location (canonical source file, line number, column 1) when
  `source_info_util.user_frame` yields a frame, and an unknown location
  otherwise.

  Args:
    primitive: primitive whose `name` is rendered into the label.
    params: equation parameters, forwarded to `core.str_eqn_compact`.
    source_info: source information from which the user frame is derived.
    name_stack: string prepended verbatim to the equation string.

  Returns:
    The named `ir.Location`.
  """
  op_label = name_stack + core.str_eqn_compact(primitive.name, params)
  user_frame = source_info_util.user_frame(source_info)
  if user_frame is not None:
    # Column is fixed at 1: only the line number is taken from the frame.
    origin = ir.Location.file(
        xla._get_canonical_source_file(user_frame), user_frame.line_num, 1)
  else:
    origin = ir.Location.unknown()
  # TODO(phawkins): also include primitive.name as the operator type.
  return ir.Location.name(op_label, childLoc=origin)
Example #3
0
def make_op_metadata(
    primitive: core.Primitive,
    params: Dict,
    *,
    source_info: source_info_util.SourceInfo,
    name_stack: str = "",
) -> xc.OpMetadata:
    """Build the `xc.OpMetadata` describing one primitive application.

    The op name is `name_stack` followed directly (no separator) by the
    compact equation string from `str_eqn_compact`. As a side effect,
    `source_info.traceback` is stored in the `tracebacks` mapping (defined
    outside this function) under that op name.

    Args:
        primitive: primitive whose `name` becomes the op type.
        params: equation parameters, forwarded to `str_eqn_compact`.
        source_info: supplies the traceback and the user frame.
        name_stack: string prefix for the op name.

    Returns:
        An `xc.OpMetadata` whose `source_file` / `source_line` are `None`
        when no user frame is available.
    """
    op_label = name_stack + str_eqn_compact(primitive.name, params)
    tracebacks[op_label] = source_info.traceback

    # NOTE(review): the truthiness guard below cannot protect against a
    # `None` source_info, because `.traceback` was already dereferenced
    # above -- confirm whether the guard is still needed.
    user_frame = None
    if source_info:
        user_frame = source_info_util.user_frame(source_info)

    if user_frame:
        file_name = _get_canonical_source_file(user_frame)
        line_no = user_frame.line_num
    else:
        file_name = line_no = None

    return xc.OpMetadata(
        op_type=primitive.name,
        op_name=op_label,
        source_file=file_name,
        source_line=line_no,
    )