def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: str = "",
                     ) -> xc.OpMetadata:
  """Build an XLA ``OpMetadata`` describing one primitive application.

  The op name is ``name_stack`` followed by the compact equation text
  produced by ``pp_eqn_compact``. As a side effect, ``source_info.traceback``
  is stored in the ``tracebacks`` mapping (defined outside this function)
  under that op name.

  Args:
    primitive: the primitive being applied; its ``name`` is used as
      ``op_type`` and in the equation text.
    params: the primitive's parameters, rendered into the equation text.
    source_info: provenance of the equation. Its ``traceback`` is recorded,
      and ``source_info_util.user_frame`` is applied to it to obtain the
      source location.
    name_stack: string prefix placed before the equation text.

  Returns:
    An ``xc.OpMetadata`` with ``op_type``, ``op_name``, ``source_file`` and
    ``source_line`` set. ``source_file`` and ``source_line`` are ``None``
    when ``user_frame`` returns a falsy value.
  """
  eqn_str = str(pp.text(name_stack) +
                pp_eqn_compact(primitive.name, params, JaxprPpContext()))
  tracebacks[eqn_str] = source_info.traceback
  # ``source_info`` is dereferenced unconditionally just above, so a
  # ``None`` value cannot reach this point. A truthiness guard here
  # (``... if source_info else None``) would be dead code, so none is used.
  frame = source_info_util.user_frame(source_info)
  return xc.OpMetadata(
      op_type=primitive.name,
      op_name=eqn_str,
      source_file=_get_canonical_source_file(frame) if frame else None,
      source_line=frame.line_num if frame else None)
def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: Union[str, source_info_util.NameStack] = "",
                     ) -> xc.OpMetadata:
  """Return the XLA ``OpMetadata`` for a single primitive application.

  The op name is a prefix followed by ``str_eqn_compact(primitive.name,
  params)``. The prefix is chosen as follows:

  * if ``config.jax_experimental_name_stack`` is set, it is
    ``str(source_info.name_stack) + '/'``, and the ``name_stack`` argument
    is not used;
  * otherwise it is ``name_stack`` itself, which is asserted to be a ``str``.

  As a side effect, ``source_info.traceback`` is stored in the ``tracebacks``
  mapping (defined outside this function) under the op name.

  Args:
    primitive: the primitive being applied; its ``name`` is the ``op_type``.
    params: the primitive's parameters, rendered into the op name.
    source_info: provenance of the equation (traceback, name stack, and the
      input to ``source_info_util.user_frame``).
    name_stack: caller-supplied prefix, consulted only when the experimental
      name-stack flag is off.

  Returns:
    An ``xc.OpMetadata``. Its ``source_file`` and ``source_line`` are
    ``None`` when ``user_frame`` returns a falsy value.
  """
  if config.jax_experimental_name_stack:
    # Flag on: the prefix comes from the stack carried by ``source_info``.
    prefix = str(source_info.name_stack) + '/'
  else:
    # Flag off: the caller-supplied string is the prefix.
    assert isinstance(name_stack, str)
    prefix = name_stack
  op_name = prefix + str_eqn_compact(primitive.name, params)
  tracebacks[op_name] = source_info.traceback

  caller_frame = source_info_util.user_frame(source_info)
  if caller_frame:
    source_file = _get_canonical_source_file(caller_frame)
    source_line = caller_frame.line_num
  else:
    source_file = source_line = None
  return xc.OpMetadata(op_type=primitive.name, op_name=op_name,
                       source_file=source_file, source_line=source_line)