Example #1
0
File: xla.py Project: wayfeng/jax
def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: str = "",
                     ) -> xc.OpMetadata:
  """Builds the ``xc.OpMetadata`` describing one equation being lowered.

  Args:
    primitive: the primitive of the equation; its ``name`` becomes ``op_type``
      and is also passed to the equation pretty-printer.
    params: the equation's parameters; only used for pretty-printing.
    source_info: source information for the equation. When truthy, its
      ``traceback`` is recorded in the ``tracebacks`` mapping (defined outside
      this function) under the equation string. Its
      ``source_info_util.user_frame`` also supplies ``source_file`` and
      ``source_line``. When falsy, nothing is recorded and both fields are
      ``None``.
    name_stack: text prepended to the pretty-printed equation in ``op_name``.

  Returns:
    An ``xc.OpMetadata`` with ``op_type=primitive.name``. ``op_name`` is the
    name stack followed by a compact pretty-print of the equation.
    ``source_file``/``source_line`` come from the user frame when one is found
    and are ``None`` otherwise.
  """
  eqn_str = str(pp.text(name_stack) +
                pp_eqn_compact(primitive.name, params, JaxprPpContext()))
  frame = None
  # Guard every use of source_info with the same check. Dereferencing
  # `source_info.traceback` before the check would crash on a missing
  # source_info and make the `frame = None` fallback unreachable.
  if source_info:
    tracebacks[eqn_str] = source_info.traceback
    frame = source_info_util.user_frame(source_info)
  return xc.OpMetadata(
      op_type=primitive.name,
      op_name=eqn_str,
      source_file=_get_canonical_source_file(frame) if frame else None,
      source_line=frame.line_num if frame else None)
Example #2
0
def _source_info_to_location(primitive: core.Primitive,
                             params: Dict,
                             source_info: source_info_util.SourceInfo,
                             name_stack: str = "") -> ir.Location:
    """Returns an MLIR location naming an equation and pointing at its source.

    The outer location is a named location. Its name is ``name_stack``
    followed by a compact pretty-print of the equation (``primitive.name``
    with ``params``). Its child is a file location (canonical source file,
    line number, column 1) built from the frame returned by
    ``source_info_util.user_frame(source_info)``. When that call returns
    ``None``, the child is an unknown location instead.
    """
    label = str(
        pp.text(name_stack) +
        core.pp_eqn_compact(primitive.name, params, core.JaxprPpContext()))
    user_frame = source_info_util.user_frame(source_info)
    if user_frame is not None:
        child = ir.Location.file(xla._get_canonical_source_file(user_frame),
                                 user_frame.line_num, 1)
    else:
        child = ir.Location.unknown()
    # TODO(phawkins): also include primitive.name as the operator type.
    return ir.Location.name(label, childLoc=child)