Exemplo n.º 1
0
    def test_create_source_map_multiline_call(self):
        fn = basic_definitions.function_with_multiline_call
        smap = self._create_source_map(fn)
        expected_path = tf_inspect.getsourcefile(fn)

        # Expected origin lines are expressed relative to where the function
        # starts in basic_definitions.py, so edits above it don't break this.
        first_line = inspect.getsourcelines(fn)[1]

        # Each case: (line in generated code, offset from first_line of the
        # original line it must map to, expected original source text).
        cases = (
            (3, 2, '  return range('),
            (5, 4, '      x + 1,'),
        )
        for generated_lineno, offset, code_line in cases:
            loc = origin_info.LineLocation('test_filename', generated_lineno)
            self.assertIn(loc, smap)
            origin = smap[loc]
            self.assertEqual(origin.loc.lineno, first_line + offset)
            self.assertEqual(origin.loc.filename, expected_path)
            self.assertEqual(origin.function_name,
                             'function_with_multiline_call')
            self.assertEqual(origin.source_code_line, code_line)
Exemplo n.º 2
0
 def fake_origin(self, function, line_offset):
     """Builds a synthetic (LineLocation, OriginInfo) pair for `function`.

     The location is in the file defining `function`, `line_offset` lines
     after the first line reported for its source.
     """
     start_lineno = tf_inspect.getsourcelines(function)[1]
     source_file = tf_inspect.getsourcefile(function)
     loc = origin_info.LineLocation(source_file, start_lineno + line_offset)
     return loc, origin_info.OriginInfo(loc, 'test_function_name', 'test_code')
Exemplo n.º 3
0
 def fake_origin(self, function, line_offset):
   """Builds a synthetic (LineLocation, OriginInfo) pair for `function`.

   The location is in the file defining `function`, `line_offset` lines after
   the first line reported for its source.
   """
   start_lineno = tf_inspect.getsourcelines(function)[1]
   source_file = tf_inspect.getsourcefile(function)
   loc = origin_info.LineLocation(source_file, start_lineno + line_offset)
   return loc, origin_info.OriginInfo(loc, 'test_function_name', 'test_code')
def resolve(nodes, source, function=None):
    """Annotates nodes with OriginInfo under the key anno.Basic.ORIGIN.

    Args:
        nodes: Union[ast.AST, Iterable[ast.AST, ...]]; a single node, or a list
            or tuple of nodes. Every node reachable from them via gast.walk
            that has a `lineno` attribute is annotated.
        source: Text, the source code string for the function whose body nodes
            will be annotated. Node line numbers are interpreted relative to
            this string; its first line is expected to be the first line
            reported by tf_inspect.getsourcelines(function).
        function: Callable, the function `source` belongs to. When given, line
            numbers are translated into line numbers of the file defining it,
            and the file path and function name are recorded. If it is None
            then only source-relative line numbers, column offsets, source
            lines and comments are set, with the rest of the information being
            None.

    Returns:
        None; the nodes are annotated in place.
    """
    if not isinstance(nodes, (list, tuple)):
        nodes = (nodes, )

    if function:
        _, function_lineno = tf_inspect.getsourcelines(function)
        function_filepath = tf_inspect.getsourcefile(function)
    else:
        function_lineno = None
        function_filepath = None

    # Map each (1-based) line of `source` to the text of the comment on it.
    # TODO(mdan): Pull this to a separate utility.
    code_reader = six.StringIO(source)
    comment_map = {}
    for token in tokenize.generate_tokens(code_reader.readline):
        tok_type, tok_string, loc, _, _ = token
        srow, _ = loc
        if tok_type == tokenize.COMMENT:
            comment_map[srow] = tok_string.strip()[1:].strip()

    source_lines = source.split('\n')
    for node in nodes:
        for n in gast.walk(node):
            if not hasattr(n, 'lineno'):
                continue

            lineno_in_body = n.lineno

            source_code_line = source_lines[lineno_in_body - 1]
            if function:
                # Line 1 of `source` is at function_lineno in the file.
                source_lineno = function_lineno + lineno_in_body - 1
                function_name = function.__name__
            else:
                source_lineno = lineno_in_body
                function_name = None

            location = Location(function_filepath, source_lineno, n.col_offset)
            # comment_map is keyed by lines of `source`, so look it up with the
            # source-relative line number, not the file line number.
            origin = OriginInfo(location, function_name, source_code_line,
                                comment_map.get(lineno_in_body))
            anno.setanno(n, anno.Basic.ORIGIN, origin)
Exemplo n.º 5
0
def resolve(nodes, source, function=None):
  """Annotates nodes with OriginInfo under the key anno.Basic.ORIGIN.

  Args:
    nodes: Union[ast.AST, Iterable[ast.AST, ...]]; a single node, or a list or
      tuple of nodes. Every node reachable from them via gast.walk that has a
      `lineno` attribute is annotated.
    source: Text, the source code string for the function whose body nodes will
      be annotated. Node line numbers are interpreted relative to this string;
      its first line is expected to be the first line reported by
      tf_inspect.getsourcelines(function).
    function: Callable, the function `source` belongs to. When given, line
      numbers are translated into line numbers of the file defining it, and the
      file path and function name are recorded. If it is None then only
      source-relative line numbers, column offsets, source lines and comments
      are set, with the rest of the information being None.

  Returns:
    None; the nodes are annotated in place.
  """
  if not isinstance(nodes, (list, tuple)):
    nodes = (nodes,)

  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None

  # Map each (1-based) line of `source` to the text of the comment on it.
  # TODO(mdan): Pull this to a separate utility.
  code_reader = six.StringIO(source)
  comment_map = {}
  for token in tokenize.generate_tokens(code_reader.readline):
    tok_type, tok_string, loc, _, _ = token
    srow, _ = loc
    if tok_type == tokenize.COMMENT:
      comment_map[srow] = tok_string.strip()[1:].strip()

  source_lines = source.split('\n')
  for node in nodes:
    for n in gast.walk(node):
      if not hasattr(n, 'lineno'):
        continue

      lineno_in_body = n.lineno

      source_code_line = source_lines[lineno_in_body - 1]
      if function:
        # Line 1 of `source` is at function_lineno in the file.
        source_lineno = function_lineno + lineno_in_body - 1
        function_name = function.__name__
      else:
        source_lineno = lineno_in_body
        function_name = None

      location = Location(function_filepath, source_lineno, n.col_offset)
      # comment_map is keyed by lines of `source`, so look it up with the
      # source-relative line number, not the file line number.
      origin = OriginInfo(location, function_name,
                          source_code_line, comment_map.get(lineno_in_body))
      anno.setanno(n, anno.Basic.ORIGIN, origin)
Exemplo n.º 6
0
def resolve_entity(node, source, entity):
  """Like resolve, but extracts the context information from an entity.

  Looks up the file path, starting line number and an approximate column
  offset of `entity`'s definition, then forwards them to `resolve`.
  """
  source_lines, start_lineno = tf_inspect.getsourcelines(entity)
  entity_path = tf_inspect.getsourcefile(entity)

  # Approximate the definition's column offset by counting the leading
  # whitespace of its first line. A tab counts as a single column here, so
  # this may be off for tab-indented code.
  header = source_lines[0]
  indent = len(header) - len(header.lstrip())

  resolve(node, source, entity_path, start_lineno, indent)
Exemplo n.º 7
0
    def test_create_source_map_identity(self):
        test_fn = basic_definitions.simple_function
        source_map = self._create_source_map(test_fn)
        module_path = tf_inspect.getsourcefile(test_fn)

        # Derive the expected origin line from where simple_function actually
        # starts in basic_definitions.py rather than hard-coding it, so the
        # test survives edits above the function.
        fn_start = tf_inspect.getsourcelines(test_fn)[1]

        definition_loc = origin_info.LineLocation('test_filename', 1)
        self.assertIn(definition_loc, source_map)
        self.assertEqual(source_map[definition_loc].loc.lineno, fn_start)
        self.assertEqual(source_map[definition_loc].loc.filename, module_path)
        self.assertEqual(source_map[definition_loc].function_name,
                         'simple_function')
def resolve(node, source, function=None):
    """Adds origin information to node and its subnodes.

    This allows us to map the original source code line numbers to generated
    source code.

    Args:
        node: gast.AST node. Should be a gast.FunctionDef. This is the node we
            annotate with origin information.
        source: Text, the source code. Should satisfy relationship
            `node in iter_tree(gast.parse(source))`; otherwise the lineno will
            be unreliable.
        function: The original function. If it is None then only the
            source-relative line numbers, column offsets, source code lines and
            comments are set in the annotation; the file path and function
            name are None.
    """
    if function:
        _, function_lineno = tf_inspect.getsourcelines(function)
        function_filepath = tf_inspect.getsourcefile(function)
    else:
        function_lineno = None
        function_filepath = None

    # Map each (1-based) line of `source` to the text of the comment on it.
    # TODO(mdan): Pull this to a separate utility.
    code_reader = six.StringIO(source)
    comment_map = {}
    for token in tokenize.generate_tokens(code_reader.readline):
        tok_type, tok_string, loc, _, _ = token
        srow, _ = loc
        if tok_type == tokenize.COMMENT:
            comment_map[srow] = tok_string.strip()[1:].strip()

    source_lines = source.split('\n')
    for n in gast.walk(node):
        if not hasattr(n, 'lineno'):
            continue

        # Distance from the root node's line; used to translate into the line
        # numbering of the file that defines `function`.
        within_body_offset = n.lineno - node.lineno

        source_code_line = source_lines[n.lineno - 1]
        if function:
            source_lineno = function_lineno + within_body_offset
            function_name = function.__name__
        else:
            source_lineno = n.lineno
            function_name = None

        location = Location(function_filepath, source_lineno, n.col_offset)
        # comment_map is keyed by lines of `source`, so look it up with the
        # source-relative n.lineno, not the file line number.
        origin = OriginInfo(location, function_name, source_code_line,
                            comment_map.get(n.lineno))
        anno.setanno(n, anno.Basic.ORIGIN, origin)
Exemplo n.º 9
0
def resolve(node, source, function=None):
  """Adds origin information to node and its subnodes.

  This allows us to map the original source code line numbers to generated
  source code.

  Args:
    node: gast.AST node. Should be a gast.FunctionDef. This is the node we
        annotate with origin information.
    source: Text, the source code. Should satisfy relationship
        `node in iter_tree(gast.parse(source))`; otherwise the lineno will be
        unreliable.
    function: The original function. If it is None then only the
        source-relative line numbers, column offsets, source code lines and
        comments are set in the annotation; the file path and function name
        are None.
  """
  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None

  # Map each (1-based) line of `source` to the text of the comment on it.
  # TODO(mdan): Pull this to a separate utility.
  code_reader = six.StringIO(source)
  comment_map = {}
  for token in tokenize.generate_tokens(code_reader.readline):
    tok_type, tok_string, loc, _, _ = token
    srow, _ = loc
    if tok_type == tokenize.COMMENT:
      comment_map[srow] = tok_string.strip()[1:].strip()

  source_lines = source.split('\n')
  for n in gast.walk(node):
    if not hasattr(n, 'lineno'):
      continue

    # Distance from the root node's line; used to translate into the line
    # numbering of the file that defines `function`.
    within_body_offset = n.lineno - node.lineno

    source_code_line = source_lines[n.lineno - 1]
    if function:
      source_lineno = function_lineno + within_body_offset
      function_name = function.__name__
    else:
      source_lineno = n.lineno
      function_name = None

    location = Location(function_filepath, source_lineno, n.col_offset)
    # comment_map is keyed by lines of `source`, so look it up with the
    # source-relative n.lineno, not the file line number.
    origin = OriginInfo(location, function_name,
                        source_code_line, comment_map.get(n.lineno))
    anno.setanno(n, anno.Basic.ORIGIN, origin)
Exemplo n.º 10
0
def resolve(nodes, source, function=None):
  """Annotates nodes with OriginInfo under the key anno.Basic.ORIGIN.

  Args:
    nodes: Union[ast.AST, Iterable[ast.AST, ...]]; a single node, or a list or
      tuple of nodes. Every node reachable from them via gast.walk that has a
      `lineno` attribute is annotated.
    source: Text, the source code string for the function whose body nodes will
      be annotated. Node line numbers are interpreted relative to this string;
      its first line is expected to be the first line reported by
      tf_inspect.getsourcelines(function).
    function: Callable, the function `source` belongs to. When given, line
      numbers are translated into line numbers of the file defining it, and the
      file path and function name are recorded. If it is None then only
      source-relative line numbers, column offsets and source lines are set,
      with the rest of the information being None.

  Returns:
    None; the nodes are annotated in place.
  """
  if not isinstance(nodes, (list, tuple)):
    nodes = (nodes,)

  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None

  source_lines = source.split('\n')
  for node in nodes:
    for n in gast.walk(node):
      if not hasattr(n, 'lineno'):
        continue

      lineno_in_body = n.lineno

      source_code_line = source_lines[lineno_in_body - 1]
      if function:
        # Line 1 of `source` is at function_lineno in the file.
        source_lineno = function_lineno + lineno_in_body - 1
        function_name = function.__name__
      else:
        source_lineno = lineno_in_body
        function_name = None

      location = Location(function_filepath, source_lineno, n.col_offset)
      origin = OriginInfo(location, function_name, source_code_line)
      anno.setanno(n, anno.Basic.ORIGIN, origin)
Exemplo n.º 11
0
  def test_runtime_error_rewriting(self):

    def g(x, s):
      while tf.reduce_sum(x) > s:
        x //= 0
      return x

    def test_fn(x):
      return g(x, 10)

    compiled_fn = ag.to_graph(test_fn)

    with self.assertRaises(ag.TfRuntimeError) as error:
      with self.cached_session() as sess:
        x = compiled_fn(tf.constant([4, 8]))
        with ag.improved_errors(compiled_fn):
          sess.run(x)
    custom_traceback = error.exception.custom_traceback
    found_correct_filename = False
    num_test_fn_frames = 0
    num_g_frames = 0
    ag_output_filename = tf_inspect.getsourcefile(compiled_fn)
    for frame in custom_traceback:
      filename, _, fn_name, source_code = frame
      # Frames from generated code and AutoGraph internals must have been
      # rewritten away; assertNotIn reports the offending value on failure.
      self.assertNotIn(ag_output_filename, filename)
      self.assertNotIn('control_flow_ops.py', filename)
      self.assertNotIn('ag__.', fn_name)
      self.assertNotIn('tf__g', fn_name)
      self.assertNotIn('tf__test_fn', fn_name)
      found_correct_filename |= __file__ in filename
      num_test_fn_frames += int('test_fn' == fn_name and
                                'return g(x, 10)' in source_code)
      # This makes sure that the code is correctly rewritten from "x_1 //= 0" to
      # "x //= 0".
      num_g_frames += int('g' == fn_name and 'x //= 0' in source_code)
    self.assertTrue(found_correct_filename)
    self.assertEqual(num_test_fn_frames, 1)
    self.assertEqual(num_g_frames, 1)
Exemplo n.º 12
0
def resolve(node, source, function=None):
  """Annotates `node` and its descendants with origin information.

  Args:
    node: The AST node for the function whose body nodes will be annotated.
    source: Text, the source code string for the function whose body nodes will
      be annotated. Node line numbers are interpreted relative to this string.
    function: Callable, the function that will have all nodes inside of it
      annotated with an OriginInfo annotation with key anno.Basic.ORIGIN. If
      it is None then only the line numbers, column offsets and source code
      lines will be set in the annotation, with the rest of the information
      being None.

  Returns:
    None; the nodes are annotated in place.
  """
  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None
  source_lines = source.split('\n')
  for n in gast.walk(node):
    if not hasattr(n, 'lineno'):
      continue
    # n.lineno is relative to the start of `source`, so offset it by the line
    # of the function in its file (line 1 of `source` maps to function_lineno).
    source_code_line = source_lines[n.lineno - 1]
    if function:
      source_lineno = n.lineno + function_lineno - 1
      function_name = function.__name__
    else:
      source_lineno = n.lineno
      function_name = None
    anno.setanno(
        n, anno.Basic.ORIGIN,
        OriginInfo(function_filepath, function_name, source_lineno,
                   n.col_offset, source_code_line))
Exemplo n.º 13
0
def resolve(node, source, function=None):
    """Annotates `node` and its descendants with origin information.

    Args:
        node: The AST node for the function whose body nodes will be
            annotated.
        source: Text, the source code string for the function whose body
            nodes will be annotated. Node line numbers are interpreted
            relative to this string.
        function: Callable, the function that will have all nodes inside of
            it annotated with an OriginInfo annotation with key
            anno.Basic.ORIGIN. If it is None then only the line numbers,
            column offsets and source code lines will be set in the
            annotation, with the rest of the information being None.

    Returns:
        None; the nodes are annotated in place.
    """
    if function:
        _, function_lineno = tf_inspect.getsourcelines(function)
        function_filepath = tf_inspect.getsourcefile(function)
    else:
        function_lineno = None
        function_filepath = None
    source_lines = source.split('\n')
    for n in gast.walk(node):
        if not hasattr(n, 'lineno'):
            continue
        # n.lineno is relative to the start of `source`, so offset it by the
        # line of the function in its file (line 1 maps to function_lineno).
        source_code_line = source_lines[n.lineno - 1]
        if function:
            source_lineno = n.lineno + function_lineno - 1
            function_name = function.__name__
        else:
            source_lineno = n.lineno
            function_name = None
        anno.setanno(
            n, anno.Basic.ORIGIN,
            OriginInfo(function_filepath, function_name, source_lineno,
                       n.col_offset, source_code_line))
Exemplo n.º 14
0
  def test_graph_construction_error_rewriting_call_tree(self):

    def innermost(x):
      if x > 0:
        return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)
      return tf.zeros((2, 3))

    def inner_caller():
      return innermost(1.0)

    def caller():
      return inner_caller()

    # `graph` is needed after the assertRaises block (to locate the generated
    # file), so conversion must succeed; only the call is expected to raise.
    graph = ag.to_graph(caller)
    with self.assertRaises(ag.GraphConstructionError) as error:
      graph()
    custom_traceback = error.exception.custom_traceback
    found_correct_filename = False
    num_innermost_names = 0
    num_inner_caller_names = 0
    num_caller_names = 0
    ag_output_filename = tf_inspect.getsourcefile(graph)
    for frame in custom_traceback:
      filename, _, fn_name, _ = frame
      self.assertNotIn('control_flow_ops.py', filename)
      self.assertNotIn(ag_output_filename, filename)
      found_correct_filename |= __file__ in filename
      # Converted (tf__-prefixed) names must be mapped back to the user's
      # function names for every function in the call tree.
      self.assertNotEqual('tf__innermost', fn_name)
      num_innermost_names += int('innermost' == fn_name)
      self.assertNotEqual('tf__inner_caller', fn_name)
      num_inner_caller_names += int('inner_caller' == fn_name)
      self.assertNotEqual('tf__caller', fn_name)
      num_caller_names += int('caller' == fn_name)
    self.assertTrue(found_correct_filename)
    self.assertEqual(num_innermost_names, 1)
    self.assertEqual(num_inner_caller_names, 1)
    self.assertEqual(num_caller_names, 1)
Exemplo n.º 15
0
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.traceable_stack."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import test_util
from tensorflow.python.framework import traceable_stack
from tensorflow.python.platform import googletest
from tensorflow.python.util import tf_inspect as inspect

# Any object defined in this module will do: getsourcefile() on it resolves to
# this test file's own path, which the tests compare against recorded
# filenames.
_LOCAL_OBJECT = lambda x: x
_THIS_FILENAME = inspect.getsourcefile(_LOCAL_OBJECT)


class TraceableObjectTest(test_util.TensorFlowTestCase):
    """Tests for traceable_stack.TraceableObject."""

    def testSetFilenameAndLineFromCallerUsesCallersStack(self):
        t_obj = traceable_stack.TraceableObject(17)

        # Keep `placeholder` on the line directly above the
        # set_filename_and_line_from_caller() call: the expected line number
        # of that call is computed as placeholder's line + 1.
        placeholder = lambda x: x
        t_obj.set_filename_and_line_from_caller()

        expected_lineno = inspect.getsourcelines(placeholder)[1] + 1
        self.assertEqual(expected_lineno, t_obj.lineno)
        self.assertEqual(_THIS_FILENAME, t_obj.filename)
Exemplo n.º 16
0
 def testGetSourceFile(self):
     """getsourcefile() on the test_decorated_function_with_defaults fixture
     must report this test module's path."""
     source_path = tf_inspect.getsourcefile(
         test_decorated_function_with_defaults)
     self.assertEqual(__file__, source_path)
Exemplo n.º 17
0
 def testGetSourceFile(self):
   """getsourcefile() on the test_decorated_function_with_defaults fixture must
   report this test module's path."""
   source_path = tf_inspect.getsourcefile(test_decorated_function_with_defaults)
   self.assertEqual(__file__, source_path)
Exemplo n.º 18
0
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.framework.traceable_stack."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import test_util
from tensorflow.python.framework import traceable_stack
from tensorflow.python.platform import googletest
from tensorflow.python.util import tf_inspect as inspect

# Any object defined in this module will do: getsourcefile() on it resolves to
# this test file's own path.
_LOCAL_OBJECT = lambda x: x
_THIS_FILENAME = inspect.getsourcefile(_LOCAL_OBJECT)


class TraceableObjectTest(test_util.TensorFlowTestCase):

  def testSetFilenameAndLineFromCallerUsesCallersStack(self):
    t_obj = traceable_stack.TraceableObject(17)

    # Do not separate placeholder from the set_filename_and_line_from_caller()
    # call one line below it as it is used to calculate the latter's line
    # number.
    placeholder = lambda x: x
    result = t_obj.set_filename_and_line_from_caller()

    expected_lineno = inspect.getsourcelines(placeholder)[1] + 1
    self.assertEqual(expected_lineno, t_obj.lineno)