Пример #1
0
def GetTempDir():
  """Return a temporary directory for tests to use.

  The directory is created on the first call, cached in the module-level
  `_googletest_temp_dir`, and registered with `atexit` for recursive deletion.
  Later calls return the cached path.

  If TEST_TMPDIR is set it is used as the mkdtemp prefix; otherwise the name is
  derived from the file of the outermost frame on the call stack.

  Returns:
    The path of the temporary directory.
  """
  global _googletest_temp_dir
  if not _googletest_temp_dir:
    if os.environ.get('TEST_TMPDIR'):
      # NOTE(review): TEST_TMPDIR is passed as a *prefix*, so unless it ends
      # with a path separator the directory is created next to TEST_TMPDIR
      # (TEST_TMPDIR + random suffix) rather than inside it. Confirm intended.
      temp_dir = tempfile.mkdtemp(prefix=os.environ['TEST_TMPDIR'])
    else:
      # Name the directory after the file of the outermost stack frame.
      first_frame = tf_inspect.stack()[-1][0]
      temp_dir = os.path.join(tempfile.gettempdir(),
                              os.path.basename(tf_inspect.getfile(first_frame)))
      # Strip only a literal '.py' suffix. str.rstrip('.py') treats its
      # argument as a character set and would also eat trailing '.', 'p' and
      # 'y' characters of the name itself (e.g. 'happy.py' -> 'ha').
      if temp_dir.endswith('.py'):
        temp_dir = temp_dir[:-len('.py')]
      temp_dir = tempfile.mkdtemp(prefix=temp_dir)

    # Make sure we have the correct path separators.
    temp_dir = temp_dir.replace('/', os.sep)

    def delete_temp_dir(dirname=temp_dir):
      try:
        file_io.delete_recursively(dirname)
      except errors.OpError as e:
        # Cleanup at exit is best-effort: log instead of raising.
        logging.error('Error removing %s: %s', dirname, e)

    atexit.register(delete_temp_dir)

    _googletest_temp_dir = temp_dir

  return _googletest_temp_dir
Пример #2
0
 def testStack(self):
   """Checks that tf_inspect.stack() agrees with inspect.stack().

   Both stacks are captured from this method on consecutive lines, so the
   innermost entries share frame, filename and function name, but their line
   numbers differ by exactly one. Keep the two capture lines adjacent.
   """
   expected_stack = inspect.stack()
   actual_stack = tf_inspect.stack()
   self.assertEqual(len(expected_stack), len(actual_stack))
   self.assertEqual(expected_stack[0][0], actual_stack[0][0])  # Frame object
   self.assertEqual(expected_stack[0][1], actual_stack[0][1])  # Filename
   # actual_stack was captured one line after expected_stack.
   self.assertEqual(expected_stack[0][2],
                    actual_stack[0][2] - 1)  # Line number
   self.assertEqual(expected_stack[0][3], actual_stack[0][3])  # Function name
   # Every entry beyond the innermost one is expected to match exactly.
   self.assertEqual(expected_stack[1:], actual_stack[1:])
Пример #3
0
def rewrite_graph_construction_error(source_map):
  """Rewrites errors raised by non-AG APIs inside AG generated code.

  This is called from the except handler inside an AutoGraph generated function
  (that is, during exception handling). Only rewrites the frames corresponding
  to the function that this is called from, so each function is responsible
  for calling this to have its own frames rewritten.

  This function always raises an error.

  Args:
    source_map: Dict[origin_info.Location, origin_info.OriginInfo], the source
        map belonging to the calling function

  Raises:
    GraphConstructionError: The rewritten underlying error.
    Exception: The underlying error, if it could not be rewritten.
  """
  # Unpack sys.exc_info() directly instead of keeping the whole tuple in a
  # local: the tuple would hold a second reference to the traceback and make
  # the `del e_traceback` in the finally clause below ineffective.
  _, original_error, e_traceback = sys.exc_info()
  assert original_error is not None
  try:
    # Name of the function whose except handler called us; entry 0 of the
    # stack is this function itself.
    _, _, _, func_name, _, _ = tf_inspect.stack()[1]
    cleaned_traceback = traceback.extract_tb(e_traceback)
    if isinstance(original_error, GraphConstructionError):
      # TODO(mdan): This is incomplete.
      # The error might have bubbled through a non-converted function.
      # Keep the first entry of the current traceback and append the frames a
      # previous call already stored on the error.
      cleaned_traceback = ([cleaned_traceback[0]] +
                           original_error.custom_traceback)

    # Remove the frame corresponding to this function call.
    cleaned_traceback = cleaned_traceback[1:]

    cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback, func_name)

    if isinstance(original_error, GraphConstructionError):
      # Update the existing error in place rather than wrapping it again.
      original_error.custom_traceback = cleaned_traceback
      new_error = original_error
    else:
      new_error = GraphConstructionError(original_error, cleaned_traceback)
  except Exception:
    # Rewriting is best-effort: on any failure, log it and surface the
    # original error unchanged.
    logging.exception('Error while rewriting AutoGraph error:')
    # TODO(mdan): Should reraise here, removing the top frame as well.
    raise original_error
  else:
    raise new_error
  finally:
    # Addresses warning https://docs.python.org/2/library/sys.html#sys.exc_info.
    del e_traceback
Пример #4
0
def _call_location():
  """Return "filename:lineno" for the caller of the function that called us.

  This is two frames above this function (e.g. the code that invoked a
  deprecated function). If there is no second frame, the immediate caller is
  reported instead.
  """
  current = tf_inspect.currentframe()
  if not current:
    # Slow fallback path: no frame objects, so build a context-free stack.
    caller_info = tf_inspect.stack(0)[2]  # 0 avoids generating unused context
    return '%s:%d' % (caller_info[1], caller_info[2])

  # CPython internals are available, use them for performance.
  direct_caller = current.f_back
  target = direct_caller.f_back or direct_caller
  return '%s:%d' % (target.f_code.co_filename, target.f_lineno)
Пример #5
0
def GetTempDir():
  """Return a temporary directory for tests to use.

  The directory is created on the first call and named after the file of the
  outermost frame on the call stack. It is cached in the module-level
  `_googletest_temp_dir` and registered with `atexit` for recursive deletion.
  Later calls return the cached path.

  Returns:
    The path of the temporary directory.
  """
  global _googletest_temp_dir
  if not _googletest_temp_dir:
    first_frame = tf_inspect.stack()[-1][0]
    temp_dir = os.path.join(tempfile.gettempdir(),
                            os.path.basename(tf_inspect.getfile(first_frame)))
    # Strip only a literal '.py' suffix. str.rstrip('.py') treats its argument
    # as a character set and would also eat trailing '.', 'p' and 'y'
    # characters of the name itself (e.g. 'happy.py' -> 'ha').
    if temp_dir.endswith('.py'):
      temp_dir = temp_dir[:-len('.py')]
    temp_dir = tempfile.mkdtemp(prefix=temp_dir)

    def delete_temp_dir(dirname=temp_dir):
      try:
        file_io.delete_recursively(dirname)
      except errors.OpError as e:
        # Cleanup at exit is best-effort: log instead of raising.
        logging.error('Error removing %s: %s', dirname, e)

    atexit.register(delete_temp_dir)
    _googletest_temp_dir = temp_dir

  return _googletest_temp_dir
Пример #6
0
 def _get_benchmark_name(self):
     """Return the name of the Benchmark method found on the call stack.

     Mostly copied from benchmark.py _get_name(). Frames are scanned from the
     outermost inwards; the first frame whose `self` is a tf.test.Benchmark
     and whose function is not named 'decorated' supplies the name. When TFRT
     is enabled, '_tfrt' is appended.

     Raises:
       ValueError: If no Benchmark frame is found on the stack.
     """
     name = None
     for frame_info in reversed(tf_inspect.stack()):
         candidate = frame_info[0].f_locals.get("self", None)
         if not isinstance(candidate, tf.test.Benchmark):
             continue
         name = frame_info[3]  # Get the method name
         # Some methods are wrapped by a disable_tfrt decorator, whose wrapper
         # function is called 'decorated'. In that case keep walking one level
         # deeper to reach the real method name.
         if name != "decorated":
             break
     if name is None:
         raise ValueError("Unable to determine calling Benchmark function.")
     if context.is_tfrt_enabled():
         name = name + "_tfrt"
     return name
Пример #7
0
    def _get_name(self, overwrite_name=None):
        """Return "<ClassName>.<method>" for the benchmark reporting a result.

        The call stack is scanned from the outermost frame inwards. The first
        frame whose `self` is a `Benchmark` instance supplies the class name,
        and also the method name unless `overwrite_name` is given.

        Args:
          overwrite_name: Optional name to use instead of the method name found
            on the stack; a falsy value falls back to the stack-derived name.

        Returns:
          The string "<class name>.<method name>".

        Raises:
          ValueError: If no frame on the stack belongs to a `Benchmark`.
        """
        benchmark_self = None
        method_name = None
        # Outermost Benchmark frame wins, so iterate from the bottom of the stack.
        for frame_info in reversed(tf_inspect.stack()):
            candidate = frame_info[0].f_locals.get("self", None)
            if isinstance(candidate, Benchmark):
                benchmark_self, method_name = candidate, frame_info[3]
                break
        if benchmark_self is None:
            raise ValueError("Unable to determine calling Benchmark class.")

        return "%s.%s" % (type(benchmark_self).__name__,
                          overwrite_name or method_name)
Пример #8
0
def GetTempDir():
    """Return a temporary directory for tests to use.

    The directory is created on the first call, cached in the module-level
    `_googletest_temp_dir`, and registered with `atexit` for recursive
    deletion. Later calls return the cached path.

    If TEST_TMPDIR is set it is used as the mkdtemp prefix; otherwise the name
    is derived from the file of the outermost frame on the call stack.

    Returns:
        The path of the temporary directory.
    """
    global _googletest_temp_dir
    if not _googletest_temp_dir:
        if os.environ.get('TEST_TMPDIR'):
            # NOTE(review): TEST_TMPDIR is passed as a *prefix*, so unless it
            # ends with a path separator the directory is created next to
            # TEST_TMPDIR (TEST_TMPDIR + random suffix) rather than inside it.
            # Confirm this is intended.
            temp_dir = tempfile.mkdtemp(prefix=os.environ['TEST_TMPDIR'])
        else:
            # Name the directory after the file of the outermost stack frame.
            first_frame = tf_inspect.stack()[-1][0]
            temp_dir = os.path.join(
                tempfile.gettempdir(),
                os.path.basename(tf_inspect.getfile(first_frame)))
            # Strip only a literal '.py' suffix. str.rstrip('.py') treats its
            # argument as a character set and would also eat trailing '.',
            # 'p' and 'y' characters of the name (e.g. 'happy.py' -> 'ha').
            if temp_dir.endswith('.py'):
                temp_dir = temp_dir[:-len('.py')]
            temp_dir = tempfile.mkdtemp(prefix=temp_dir)

        def delete_temp_dir(dirname=temp_dir):
            try:
                file_io.delete_recursively(dirname)
            except errors.OpError as e:
                # Cleanup at exit is best-effort: log instead of raising.
                logging.error('Error removing %s: %s', dirname, e)

        atexit.register(delete_temp_dir)
        _googletest_temp_dir = temp_dir

    return _googletest_temp_dir
Пример #9
0
  def _get_name(self, overwrite_name=None):
    """Return "<ClassName>.<method>" for the benchmark reporting a result.

    The call stack is scanned from the outermost frame inwards. The first frame
    whose `self` is a `Benchmark` instance supplies the class name, and also the
    method name unless `overwrite_name` is given.

    Args:
      overwrite_name: Optional name to use instead of the method name found on
        the stack; a falsy value falls back to the stack-derived name.

    Returns:
      The string "<class name>.<method name>".

    Raises:
      ValueError: If no frame on the stack belongs to a `Benchmark`.
    """
    benchmark_self = None
    method_name = None
    # Outermost Benchmark frame wins, so iterate from the bottom of the stack.
    for frame_info in reversed(tf_inspect.stack()):
      candidate = frame_info[0].f_locals.get("self", None)
      if isinstance(candidate, Benchmark):
        benchmark_self, method_name = candidate, frame_info[3]
        break
    if benchmark_self is None:
      raise ValueError("Unable to determine calling Benchmark class.")

    return "%s.%s" % (type(benchmark_self).__name__,
                      overwrite_name or method_name)
Пример #10
0
def line_number_above():
    """Return the line number immediately above the caller's current line."""
    caller_record = tf_inspect.stack()[1]  # entry 0 is this function
    caller_line = caller_record[2]
    return caller_line - 1
Пример #11
0
def line_number_above():
  """Return the line number immediately above the caller's current line."""
  caller_record = tf_inspect.stack()[1]  # entry 0 is this function
  caller_line = caller_record[2]
  return caller_line - 1
Пример #12
0
def rewrite_graph_construction_error(source_map):
    """Rewrites errors raised by non-AG APIs inside AG generated code.

    Meant to be called from the try/except block inside each AutoGraph
    generated function.  Only rewrites the traceback frames corresponding to
    the function that this is called from.  When we raise a
    GraphConstructionError at the end it is then caught by calling functions,
    where they can be responsible for rewriting their own frames.

    This function always raises.

    Args:
      source_map: Dict[CodeLocation, OriginInfo], a mapping between the user
          and AG generated code.

    Raises:
      GraphConstructionError: The rewritten underlying error.
      Exception: The underlying error, if it could not be rewritten.
    """
    # Unpack sys.exc_info() directly instead of keeping the whole tuple in a
    # local: the tuple would hold a second reference to the traceback and make
    # the `del e_traceback` in the finally clause below ineffective.
    _, original_error, e_traceback = sys.exc_info()
    assert original_error is not None
    try:
        # Name of the function whose except block called us; entry 0 of the
        # stack is this function itself.
        _, _, _, func_name, _, _ = tf_inspect.stack()[1]
        # The latest function call is added to the beginning of a traceback, but
        # when rewriting the traceback of multiple function calls the previous
        # functions' except blocks may have already rewritten their own frames so
        # we want to copy over all of the previous frames. We may have rewritten
        # previous frames only if the error is a GraphConstructionError.
        cleaned_traceback = traceback.extract_tb(e_traceback)
        if isinstance(original_error, GraphConstructionError):
            cleaned_traceback = ([cleaned_traceback[0]] +
                                 original_error.custom_traceback)
        cleaned_traceback = _remove_rewrite_frames(cleaned_traceback)

        current_frame_indices = []
        # This code is meant to be called from the try/except block that wraps
        # a function body.  Find the FIRST frame that came from the function
        # this wraps (the loop stops at the first match) and pass its index,
        # together with the source map, to _rewrite_frame.
        for fi, frame in enumerate(cleaned_traceback):
            _, _, frame_func_name, _ = frame
            if frame_func_name == func_name:
                current_frame_indices.append(fi)
                break
        if current_frame_indices:
            _rewrite_frame(source_map, cleaned_traceback,
                           current_frame_indices)

        if isinstance(original_error, GraphConstructionError):
            # Update the existing error in place rather than wrapping it again.
            original_error.custom_traceback = cleaned_traceback
            new_error = original_error
        else:
            new_error = GraphConstructionError(original_error,
                                               cleaned_traceback)
    except Exception:
        # Rewriting is best-effort: on any failure, log it and surface the
        # original error unchanged.
        logging.exception('Error while rewriting AutoGraph error:')
        raise original_error
    else:
        raise new_error
    finally:
        # Addresses warning https://docs.python.org/2/library/sys.html#sys.exc_info.
        del e_traceback
Пример #13
0
def rewrite_graph_construction_error(source_map):
  """Rewrites errors raised by non-AG APIs inside AG generated code.

  Meant to be called from the try/except block inside each AutoGraph generated
  function.  Only rewrites the traceback frames corresponding to the function
  that this is called from.  When we raise a GraphConstructionError at the end
  it is then caught by calling functions, where they can be responsible for
  rewriting their own frames.

  This function always raises.

  Args:
    source_map: Dict[CodeLocation, OriginInfo], a mapping between the user and
        AG generated code.

  Raises:
    GraphConstructionError: The rewritten underlying error.
    Exception: The underlying error, if it could not be rewritten.
  """
  # Unpack sys.exc_info() directly instead of keeping the whole tuple in a
  # local: the tuple would hold a second reference to the traceback and make
  # the `del e_traceback` in the finally clause below ineffective.
  _, original_error, e_traceback = sys.exc_info()
  assert original_error is not None
  try:
    # Name of the function whose except block called us; entry 0 of the stack
    # is this function itself.
    _, _, _, func_name, _, _ = tf_inspect.stack()[1]
    # The latest function call is added to the beginning of a traceback, but
    # when rewriting the traceback of multiple function calls the previous
    # functions' except blocks may have already rewritten their own frames so
    # we want to copy over all of the previous frames. We may have rewritten
    # previous frames only if the error is a GraphConstructionError.
    cleaned_traceback = traceback.extract_tb(e_traceback)
    if isinstance(original_error, GraphConstructionError):
      cleaned_traceback = ([cleaned_traceback[0]] +
                           original_error.custom_traceback)
    cleaned_traceback = _remove_rewrite_frames(cleaned_traceback)

    current_frame_indices = []
    # This code is meant to be called from the try/except block that wraps a
    # function body.  Find the FIRST frame that came from the function this
    # wraps (the loop stops at the first match) and pass its index, together
    # with the source map, to _rewrite_frame.
    for fi, frame in enumerate(cleaned_traceback):
      _, _, frame_func_name, _ = frame
      if frame_func_name == func_name:
        current_frame_indices.append(fi)
        break
    if current_frame_indices:
      _rewrite_frame(source_map, cleaned_traceback, current_frame_indices)

    if isinstance(original_error, GraphConstructionError):
      # Update the existing error in place rather than wrapping it again.
      original_error.custom_traceback = cleaned_traceback
      new_error = original_error
    else:
      new_error = GraphConstructionError(original_error, cleaned_traceback)
  except Exception:
    # Rewriting is best-effort: on any failure, log it and surface the
    # original error unchanged.
    logging.exception('Error while rewriting AutoGraph error:')
    raise original_error
  else:
    raise new_error
  finally:
    # Addresses warning https://docs.python.org/2/library/sys.html#sys.exc_info.
    del e_traceback