Example #1
0
def user_context(c):
    """Make ``c`` the active source-info context while the body runs.

    NOTE(review): this is a generator; it is presumably wrapped with
    ``contextlib.contextmanager`` at its real definition site (the decorator
    is not visible in this excerpt) and used as ``with user_context(tb): ...``.

    The previous context is captured on entry and restored on exit, whether
    the body completes or raises.

    Args:
      c: The context/traceback to install, or ``None`` to keep the current
        one. To decorate escaping exceptions it must provide
        ``as_python_traceback()``; per the TODO below, objects from jaxlib
        releases older than 0.1.66 may lack it.

    Raises:
      Re-raises any ``Exception`` raised by the body. Before re-raising, when
      ``c`` is not ``None``, ``has_user_context`` is false for the exception,
      and the filtered form of ``c.as_python_traceback()`` is non-empty, a
      ``JaxStackTraceBeforeTransformation`` carrying that filtered traceback
      is attached as the exception's ``__cause__``.
    """
    prev = _source_info_context.context
    # Explicit None test, consistent with the `c is None` guard below.
    _source_info_context.context = c if c is not None else prev
    filtered_tb = None
    try:
        yield
    except Exception as e:
        if c is None or has_user_context(e):
            raise
        # TODO(phawkins): remove the following condition after Jaxlib 0.1.66 is the
        # minimum.
        if not hasattr(c, 'as_python_traceback'):
            raise
        filtered_tb = traceback_util.filter_traceback(c.as_python_traceback())
        if filtered_tb:
            msg = traceback_util.format_exception_only(e)
            msg = f'{msg}\n\n{_message}'
            # Bind the new exception to its own name instead of clobbering
            # the `c` parameter.
            exp = JaxStackTraceBeforeTransformation(msg).with_traceback(
                filtered_tb)
            # Move e's existing chain onto the new exception. Assigning
            # __cause__ implicitly sets __suppress_context__ to True, so e's
            # original flag must be copied after __cause__ is set.
            exp.__context__ = e.__context__
            exp.__cause__ = e.__cause__
            exp.__suppress_context__ = e.__suppress_context__
            # Re-point e so the new exception is reported as its direct cause.
            e.__context__ = None
            e.__cause__ = exp
        raise
    finally:
        _source_info_context.context = prev
        # NOTE(review): presumably drops this frame's reference to the
        # synthetic traceback; confirm intent.
        del filtered_tb
Example #2
0
def user_context(c: Optional[Traceback],
                 *,
                 name_stack: Optional[NameStack] = None):
    """Run the body with ``c`` and ``name_stack`` as the source-info context.

    NOTE(review): this is a generator; it is presumably decorated with
    ``contextlib.contextmanager`` at its real definition site (the decorator
    is not visible in this excerpt).

    On entry, the current context object is saved. A copy with ``traceback``
    and ``name_stack`` replaced is then installed, even when ``None`` is
    passed for either. The saved object is put back on exit, whether the body
    completes or raises.

    Raises:
      Whatever ``Exception`` the body raised, re-raised as the same object.
      Before re-raising, when ``c`` is not ``None``, ``has_user_context`` is
      false for the exception, and the filtered form of
      ``c.as_python_traceback()`` is non-empty, a
      ``JaxStackTraceBeforeTransformation`` carrying that filtered traceback
      is attached as the exception's ``__cause__``.
    """
    saved = _source_info_context.context
    _source_info_context.context = saved.replace(
        traceback=c, name_stack=name_stack)
    user_tb = None
    try:
        yield
    except Exception as err:
        if c is None or has_user_context(err):
            raise
        user_tb = traceback_util.filter_traceback(c.as_python_traceback())
        if not user_tb:
            # No user frames survived filtering: propagate untouched.
            raise
        banner = f'{traceback_util.format_exception_only(err)}\n\n{_message}'
        wrapper = JaxStackTraceBeforeTransformation(banner).with_traceback(
            user_tb)
        # Hand err's chain over to the wrapper. Order matters: assigning
        # __cause__ forces __suppress_context__ to True, so the original flag
        # is copied only after __cause__ has been set.
        wrapper.__context__ = err.__context__
        wrapper.__cause__ = err.__cause__
        wrapper.__suppress_context__ = err.__suppress_context__
        # Make the wrapper appear as err's direct cause.
        err.__context__ = None
        err.__cause__ = wrapper
        raise
    finally:
        _source_info_context.context = saved
        del user_tb
Example #3
0
def user_context(c):
  """Make ``c`` the active source-info context while the body runs.

  NOTE(review): this is a generator; it is presumably wrapped with
  ``contextlib.contextmanager`` at its real definition site (the decorator
  is not visible in this excerpt).

  The previous context is captured on entry and restored on exit, whether
  the body completes or raises.

  Args:
    c: The context/traceback to install, or ``None`` to keep the current
      one. When not ``None`` it must provide ``as_python_traceback()``.

  Raises:
    Re-raises any ``Exception`` raised by the body. Before re-raising, when
    ``c`` is not ``None``, ``has_user_context`` is false for the exception,
    and the filtered form of ``c.as_python_traceback()`` is non-empty, a
    ``JaxStackTraceBeforeTransformation`` carrying that filtered traceback
    is attached as the exception's ``__cause__``.
  """
  prev = _source_info_context.context
  # Explicit None test, consistent with the `c is None` guard below.
  _source_info_context.context = c if c is not None else prev
  filtered_tb = None
  try:
    yield
  except Exception as e:
    if c is None or has_user_context(e):
      raise
    filtered_tb = traceback_util.filter_traceback(c.as_python_traceback())
    if filtered_tb:
      msg = traceback_util.format_exception_only(e)
      msg = f'{msg}\n\n{_message}'
      # Bind the new exception to its own name instead of clobbering the
      # `c` parameter.
      exp = JaxStackTraceBeforeTransformation(msg).with_traceback(filtered_tb)
      # Move e's existing chain onto the new exception. Assigning __cause__
      # implicitly sets __suppress_context__ to True, so e's original flag
      # must be copied after __cause__ is set.
      exp.__context__ = e.__context__
      exp.__cause__ = e.__cause__
      exp.__suppress_context__ = e.__suppress_context__
      # Re-point e so the new exception is reported as its direct cause.
      e.__context__ = None
      e.__cause__ = exp
    raise
  finally:
    _source_info_context.context = prev
    # NOTE(review): presumably drops this frame's reference to the synthetic
    # traceback; confirm intent.
    del filtered_tb