def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: Union[str, source_info_util.NameStack] = "",
                     ) -> xc.OpMetadata:
  """Return the XLA ``OpMetadata`` describing a single equation.

  The op name is a name-stack prefix followed by the compact string form of
  ``primitive`` applied to ``params``:

  * When ``config.jax_experimental_name_stack`` is enabled, the prefix is
    ``str(source_info.name_stack)`` followed by ``'/'``, and the
    ``name_stack`` argument is not used.
  * Otherwise ``name_stack`` must be a ``str`` and is prepended unchanged.

  Side effect: ``source_info.traceback`` is stored in the module-level
  ``tracebacks`` mapping, keyed by the op name.

  Args:
    primitive: the primitive; its ``name`` becomes ``op_type``.
    params: equation parameters, rendered via ``str_eqn_compact``.
    source_info: source information whose user frame supplies the source file
      and line fields.
    name_stack: string name-stack prefix used only when the experimental name
      stack is disabled.

  Returns:
    An ``xc.OpMetadata``. ``source_file`` and ``source_line`` are ``None`` when
    ``source_info_util.user_frame`` yields no frame.
  """
  if not config.jax_experimental_name_stack:
    # Non-experimental path: the name stack arrives as a plain string.
    assert isinstance(name_stack, str)
    op_name = name_stack + str_eqn_compact(primitive.name, params)
  else:
    prefix = str(source_info.name_stack) + '/'
    op_name = prefix + str_eqn_compact(primitive.name, params)
  tracebacks[op_name] = source_info.traceback
  user_frame = source_info_util.user_frame(source_info)
  if user_frame:
    source_file = _get_canonical_source_file(user_frame)
    source_line = user_frame.line_num
  else:
    source_file = None
    source_line = None
  return xc.OpMetadata(
      op_type=primitive.name,
      op_name=op_name,
      source_file=source_file,
      source_line=source_line)
def _source_info_to_location(
    primitive: core.Primitive, params: Dict,
    source_info: source_info_util.SourceInfo,
    name_stack: str = "") -> ir.Location:
  """Return an MLIR location describing a single equation.

  The result is a named location. Its name is ``name_stack`` followed by the
  compact string form of ``primitive`` applied to ``params``, and it wraps a
  child location:

  * a file/line location taken from the user frame of ``source_info``, with
    the column fixed at 1; or
  * ``ir.Location.unknown()`` when no user frame is found.

  Args:
    primitive: the primitive being lowered.
    params: equation parameters, rendered via ``core.str_eqn_compact``.
    source_info: source information used to find the user frame.
    name_stack: string prefix for the location name.

  Returns:
    The named ``ir.Location``.
  """
  loc_name = name_stack + core.str_eqn_compact(primitive.name, params)
  user_frame = source_info_util.user_frame(source_info)
  child = (ir.Location.unknown() if user_frame is None else
           ir.Location.file(xla._get_canonical_source_file(user_frame),
                            user_frame.line_num, 1))
  # TODO(phawkins): also include primitive.name as the operator type.
  return ir.Location.name(loc_name, childLoc=child)
def make_op_metadata(
    primitive: core.Primitive,
    params: Dict, *,
    source_info: source_info_util.SourceInfo,
    name_stack: str = "",
    ) -> xc.OpMetadata:
  """Return the XLA ``OpMetadata`` describing a single equation.

  The op name is ``name_stack`` followed by the compact string form of
  ``primitive`` applied to ``params``. When ``source_info`` is provided, its
  ``traceback`` is recorded in the module-level ``tracebacks`` mapping under
  that op name.

  Args:
    primitive: the primitive; its ``name`` becomes ``op_type``.
    params: equation parameters, rendered via ``str_eqn_compact``.
    source_info: source information whose user frame supplies the source file
      and line fields. The body also tolerates ``None`` here; the annotation
      says ``SourceInfo`` — NOTE(review): confirm whether any caller actually
      passes ``None``.
    name_stack: string prefix for the op name.

  Returns:
    An ``xc.OpMetadata``. ``source_file`` and ``source_line`` are ``None`` when
    there is no source info or no user frame.
  """
  eqn_str = name_stack + str_eqn_compact(primitive.name, params)
  # The frame lookup below already tolerates a missing ``source_info``. Guard
  # the traceback bookkeeping the same way; otherwise a ``None`` raises
  # AttributeError here before that check is ever reached.
  if source_info is not None:
    tracebacks[eqn_str] = source_info.traceback
  frame = source_info_util.user_frame(source_info) if source_info else None
  return xc.OpMetadata(
      op_type=primitive.name,
      op_name=eqn_str,
      source_file=_get_canonical_source_file(frame) if frame else None,
      source_line=frame.line_num if frame else None)