def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: str = "",
                     ) -> xc.OpMetadata:
    """Build the ``xc.OpMetadata`` describing one application of ``primitive``.

    The operation name is ``name_stack`` followed by the compact
    pretty-printed form of the equation (``primitive.name`` applied with
    ``params``). As a side effect, the traceback carried by ``source_info``
    is stored in the ``tracebacks`` mapping under that operation name.

    Args:
      primitive: The primitive being applied; its ``name`` becomes ``op_type``.
      params: The primitive's parameters, used only for pretty-printing.
      source_info: Source information for the equation. A falsy value is
        tolerated: no user frame is looked up and ``None`` is recorded as the
        traceback.
      name_stack: Text prepended to the pretty-printed equation.

    Returns:
      An ``xc.OpMetadata`` whose ``source_file`` and ``source_line`` are taken
      from the user frame of ``source_info``, or are ``None`` when no such
      frame is available.
    """
    eqn_str = str(pp.text(name_stack) +
                  pp_eqn_compact(primitive.name, params, JaxprPpContext()))

    if source_info:
        tracebacks[eqn_str] = source_info.traceback
        frame = source_info_util.user_frame(source_info)
    else:
        # Guard the traceback access the same way as the frame lookup, so a
        # missing source_info degrades to "unknown location" instead of
        # raising AttributeError. Overwrite any earlier entry for this name so
        # a stale traceback is not attributed to this equation.
        tracebacks[eqn_str] = None
        frame = None

    if frame:
        source_file = _get_canonical_source_file(frame)
        source_line = frame.line_num
    else:
        source_file = None
        source_line = None

    return xc.OpMetadata(
        op_type=primitive.name,
        op_name=eqn_str,
        source_file=source_file,
        source_line=source_line)
def _source_info_to_location(
        primitive: core.Primitive, params: Dict,
        source_info: source_info_util.SourceInfo,
        name_stack: str = "") -> ir.Location:
    """Return an ``ir.Location`` naming the equation for ``primitive``.

    The result is a named location whose name is ``name_stack`` followed by
    the compact pretty-printed equation. Its child location is the file and
    line of the user frame found in ``source_info`` (column fixed at 1), or an
    unknown location when no user frame is found.

    Args:
      primitive: The primitive whose application is being located.
      params: The primitive's parameters, used only for pretty-printing.
      source_info: Source information from which the user frame is extracted.
      name_stack: Text prepended to the pretty-printed equation.

    Returns:
      The named ``ir.Location`` wrapping the file or unknown location.
    """
    printed_eqn = core.pp_eqn_compact(
        primitive.name, params, core.JaxprPpContext())
    location_name = str(pp.text(name_stack) + printed_eqn)

    user_frame = source_info_util.user_frame(source_info)
    child_location = (
        ir.Location.unknown()
        if user_frame is None
        else ir.Location.file(
            xla._get_canonical_source_file(user_frame),
            user_frame.line_num,
            1))

    # TODO(phawkins): also include primitive.name as the operator type.
    return ir.Location.name(location_name, childLoc=child_location)