예제 #1
0
파일: xla.py 프로젝트: wayfeng/jax
def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: str = "",
                     ) -> xc.OpMetadata:
  """Build the XLA ``OpMetadata`` for one primitive application.

  The op name is the ``name_stack`` text followed by a compact pretty-printed
  form of ``primitive`` and ``params``. When ``source_info`` is provided, its
  traceback is stored in ``tracebacks`` under that op name. The user frame
  derived from it supplies the source file and line.

  Args:
    primitive: the primitive being lowered; its ``name`` becomes ``op_type``.
    params: the primitive's parameters, rendered into the op name.
    source_info: source information for the equation. ``None`` is tolerated:
      no traceback is recorded and no source location is attached.
    name_stack: text prefixed to the op name.

  Returns:
    An ``xc.OpMetadata`` carrying ``op_type`` and ``op_name``. It also carries
    ``source_file``/``source_line`` when a user frame is found, otherwise
    ``None`` for both.
  """
  eqn_str = str(pp.text(name_stack) +
                pp_eqn_compact(primitive.name, params, JaxprPpContext()))
  if source_info is None:
    # The frame lookup already treated a missing source_info as "no location".
    # Handle it consistently here rather than raising AttributeError on
    # ``source_info.traceback``.
    frame = None
  else:
    tracebacks[eqn_str] = source_info.traceback
    frame = source_info_util.user_frame(source_info)
  return xc.OpMetadata(
        op_type=primitive.name,
        op_name=eqn_str,
        source_file=_get_canonical_source_file(frame) if frame else None,
        source_line=frame.line_num if frame else None)
예제 #2
0
파일: xla.py 프로젝트: John1Tang/jax
def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: Union[str, source_info_util.NameStack] = "",
                     ) -> xc.OpMetadata:
  """Build the XLA ``OpMetadata`` for one primitive application.

  The op name is a prefix followed by ``str_eqn_compact(primitive.name,
  params)``. How the prefix is chosen depends on the config flag:

  * If ``config.jax_experimental_name_stack`` is set, the prefix is
    ``str(source_info.name_stack)`` plus ``'/'``. The ``name_stack`` argument
    is not used in this case.
  * Otherwise ``name_stack`` must be a plain ``str`` (asserted). It is used
    directly as the prefix, with no separator.

  The traceback from ``source_info`` is stored in ``tracebacks`` under the op
  name. The user frame derived from ``source_info`` supplies the source file
  and line when one is found.

  Args:
    primitive: the primitive being lowered; its ``name`` becomes ``op_type``.
    params: the primitive's parameters, rendered into the op name.
    source_info: source information for the equation.
    name_stack: legacy string prefix, used only when the experimental
      name-stack flag is off.

  Returns:
    An ``xc.OpMetadata`` carrying ``op_type``, ``op_name``, ``source_file``
    and ``source_line``. The last two are ``None`` when no user frame exists.
  """
  if config.jax_experimental_name_stack:
    # The name stack comes from source_info; the name_stack argument is unused.
    prefix = str(source_info.name_stack) + '/'
  else:
    assert isinstance(name_stack, str)
    prefix = name_stack
  eqn_str = prefix + str_eqn_compact(primitive.name, params)

  tracebacks[eqn_str] = source_info.traceback

  frame = source_info_util.user_frame(source_info)
  if frame:
    source_file = _get_canonical_source_file(frame)
    source_line = frame.line_num
  else:
    source_file = source_line = None

  return xc.OpMetadata(op_type=primitive.name, op_name=eqn_str,
                       source_file=source_file, source_line=source_line)