コード例 #1
0
    def _visit_For(self, node):
        """Visit the iterable of a ``for`` loop and the jumps inside the loop.

        Recognised iterable shapes:
          * ``for ... in var``                             -> ``self.visit(var)``
          * ``for ... in range(...)`` / ``enumerate(...)`` -> visit each argument
          * ``for ... in x.numpy()``                       -> ``self._visit_Call``
        For any other iterable the method returns immediately and visits
        nothing, including the loop's break/continue statements.
        """
        assert isinstance(node, gast.For)
        loop_iter = node.iter

        if isinstance(loop_iter, gast.Name):
            # for ... in var
            self.visit(loop_iter)
        elif isinstance(loop_iter, gast.Call):
            callee = loop_iter.func
            if isinstance(callee, gast.Name):
                # for ... in range(var[0]|var.numpy()[0]) or enumerate(var|var.numpy())
                if callee.id not in ("range", "enumerate"):
                    return
                for arg in loop_iter.args:
                    self.visit(arg)
            elif isinstance(callee, gast.Attribute):
                # for ... in var.numpy()
                if callee.attr != 'numpy':
                    return
                self._visit_Call(loop_iter)
            else:
                return
        else:
            return

        # gast.walk descends into nested loops too, so their break/continue
        # statements are handed to _visit_break_continue as well.
        for child in gast.walk(node):
            if isinstance(child, (gast.Continue, gast.Break)):
                self._visit_break_continue(child)
コード例 #2
0
 def synchronize_lcds(self, node):
     """Append ``name = name._synchronize()`` for each loop-carried dependency.

     A top-level ``name = expr`` in ``node.body`` whose target name is loaded
     earlier in the body (or in the same statement) is treated as a
     loop-carried dependency (LCD). One synchronize statement per LCD is
     appended to ``node.body`` in the order the LCDs first appear, so the
     generated code is identical from run to run.

     Args:
       node: AST node with a ``body`` list.

     Returns:
       The transformed node.

     Raises:
       NotImplementedError: if an LCD is assigned more than once.
     """
     node = FuseAttributes().visit(node)
     loads = defaultdict(list)
     # A list (rather than a set) keeps the emitted statements in source
     # order; iterating a set of str is hash-randomised between runs.
     lcds = []
     for child in node.body:
         for n in gast.walk(child):
             if isinstance(n, gast.Name) and isinstance(n.ctx, gast.Load):
                 loads[n.id].append(n)
         if isinstance(child, gast.Assign):
             name = child.targets[0].id
             if name in loads:
                 if name in lcds:
                     raise NotImplementedError("cannot process LCD "
                                               "stored to twice")
                 lcds.append(name)
     node = SplitAttributes().visit(node)
     synchronizes = [
         gast.Assign(
             [gast.Name(name, gast.Store(), None)],
             gast.Call(
                 gast.Attribute(
                     gast.Name(name, gast.Load(), None),
                     gast.Name('_synchronize', gast.Load(), None), None),
                 [], []))
         for name in lcds
     ]
     node.body.extend(synchronizes)
     return node
コード例 #3
0
ファイル: optimization.py プロジェクト: zouzias/tangent
def dead_code_elimination(node):
    """Perform a simple form of dead code elimination on a Python AST.

    Reaching-definitions analysis (``annotate.unused``) finds definitions whose
    values are never used, and those definitions are removed. Function
    arguments and ``for`` loops are never removed.

    Push and pop statements are handled together. If a removed statement
    contains a node annotated with ``'push'``, the node stored in that
    annotation is removed too. Because of this, dead code elimination must be
    performed on the primal and adjoint *simultaneously*.

    Args:
      node: The AST to optimize.

    Returns:
      The optimized AST. Its annotations are cleared via ``anno.clearanno``.
    """
    to_remove = {
        def_[1] for def_ in annotate.unused(node)
        if not isinstance(def_[1], (gast.arguments, gast.For))
    }
    # Iterate a snapshot: the pushes added here are not themselves scanned.
    for removed in tuple(to_remove):
        to_remove.update(
            anno.getanno(succ, 'push') for succ in gast.walk(removed)
            if anno.getanno(succ, 'push', False))
    transformers.Remove(to_remove).visit(node)
    anno.clearanno(node)
    return node
コード例 #4
0
 def visit_loop(self, node, update_mask=None):
     """Rewrite loop-carried dependencies (LCDs) in ``node.body``.

     A top-level ``name = expr`` in ``node.body`` whose target name is loaded
     earlier in the body (or in the same statement) is treated as an LCD. The
     assignment becomes ``name = name._update(expr, update_mask)``. One
     ``name = name._synchronize()`` statement per LCD is appended to
     ``node.body``, in the order the LCDs first appear.

     Args:
       node: AST node with a ``body`` list (presumably a loop; TODO confirm).
       update_mask: AST expression passed as the second argument of each
         ``_update`` call. When omitted (or None), a fresh ``None`` constant
         node is built for every call site. The previous default was a single
         node object created once at definition time and spliced into every
         tree. A caller-supplied node is inserted as-is at every call site.

     Returns:
       The transformed node.

     Raises:
       NotImplementedError: if an assignment has several targets, or if an
         LCD is assigned more than once.
     """
     node = FuseAttributes().visit(node)
     loads = defaultdict(list)
     # Ordered list (not a set) so the appended _synchronize statements are
     # emitted in a deterministic, source-based order.
     stores = []
     for child in node.body:
         for n in gast.walk(child):
             if isinstance(n, gast.Name) and isinstance(n.ctx, gast.Load):
                 loads[n.id].append(n)
         if isinstance(child, gast.Assign):
             if len(child.targets) > 1:
                 raise NotImplementedError("cannot process LCD that is "
                                           "part of multiple assignment")
             name = child.targets[0].id
             if name in loads:
                 if name in stores:
                     raise NotImplementedError("cannot process LCD "
                                               "stored to twice")
                 mask = (gast.NameConstant(value=None)
                         if update_mask is None else update_mask)
                 # $var = $expr -> $var = $var._update($expr)
                 child.value = gast.Call(
                     gast.Attribute(gast.Name(name, gast.Load(), None),
                                    gast.Name('_update', gast.Load(), None),
                                    None), [child.value, mask], [])
                 stores.append(name)
     node = SplitAttributes().visit(node)
     synchronizes = [
         gast.Assign(
             [gast.Name(name, gast.Store(), None)],
             gast.Call(
                 gast.Attribute(
                     gast.Name(name, gast.Load(), None),
                     gast.Name('_synchronize', gast.Load(), None), None),
                 [], []))
         for name in stores
     ]
     node.body.extend(synchronizes)
     return node
コード例 #5
0
def compute_same_identifier_edges(tree, ast_to_node_id):
    """Compute EXTRA_SAME_IDENTIFIER edges from an AST.

    Every pair of `Name` nodes sharing an identifier is connected in both
    directions. Every `Name` node also gets a self-edge.

    Args:
      tree: The AST to construct an example for.
      ast_to_node_id: Dictionary that maps AST node ids (``id(node)``) to their
        graph node id.

    Returns:
      List of ``(source_id, target_id, SAME_IDENTIFIER_EDGE_TYPE)`` tuples.
    """
    edges = []
    ids_by_name = collections.defaultdict(list)
    name_nodes = (n for n in gast.walk(tree) if isinstance(n, gast.Name))
    for name_node in name_nodes:
        node_id = ast_to_node_id[id(name_node)]
        earlier_ids = ids_by_name[name_node.id]  # pytype: disable=attribute-error
        for other_id in earlier_ids:
            edges.extend([
                (node_id, other_id, SAME_IDENTIFIER_EDGE_TYPE),
                (other_id, node_id, SAME_IDENTIFIER_EDGE_TYPE),
            ])
        earlier_ids.append(node_id)
        edges.append((node_id, node_id, SAME_IDENTIFIER_EDGE_TYPE))

    return edges
コード例 #6
0
    def test_nested_loop_vars(self):
        """Check NameVisitor's loop/created variable names for nested loops.

        Expected values are listed in ``gast.walk`` order of the loop nodes.
        The test also checks that exactly that many loops were visited, so it
        cannot pass vacuously.
        """
        func = self.nested_for_loop_func
        test_func = inspect.getsource(func)
        gast_root = gast.parse(test_func)
        name_visitor = NameVisitor(gast_root)

        self.loop_var_names = [
            set(["j", "two"]),
            set(["i", "three", "b"]),
            set(["i", "j"])
        ]
        self.create_var_names = [set(), set(["b"]), set()]

        i = 0
        for node in gast.walk(gast_root):
            if isinstance(node, (gast.While, gast.For)):
                loop_var_names, create_var_names = name_visitor.get_loop_var_names(
                    node)
                self.assertEqual(
                    loop_var_names,
                    self.loop_var_names[i],
                    msg="loop_var_names : {}, \nexpected loop_var_names : {}".
                    format(loop_var_names, self.loop_var_names[i]))
                self.assertEqual(
                    create_var_names,
                    self.create_var_names[i],
                    msg=
                    "i = {}\ncreate_var_names : {}, \nexpected create_var_names : {}"
                    .format(i, create_var_names, self.create_var_names[i]))
                i += 1
        # Without this, finding fewer loops than expected would go unnoticed.
        self.assertEqual(
            i,
            len(self.loop_var_names),
            msg="visited {} loops, expected {}".format(
                i, len(self.loop_var_names)))
コード例 #7
0
    def _transform_var_shape_if_necessary(self, cond):
        """Replace var-shape references inside ``cond`` with converted nodes.

        A node is replaced when it is an ``Attribute`` for which
        ``self.is_var_shape`` is true, or a ``Name`` whose id is a key of
        ``self.name_to_var_shape``. The parent is found through
        ``self.node_to_wrapper_map``, and the first slot of the parent that
        holds the node (a plain field, or an element of a list field such as
        ``gast.Compare.comparators``) is overwritten with
        ``create_convert_shape_node(...)``.

        Returns:
          True if at least one replacement candidate was found, else False.
        """
        transformed = False
        for child_node in gast.walk(cond):
            shape_node = None
            if isinstance(child_node, gast.Attribute):
                if self.is_var_shape(child_node):
                    shape_node = child_node
            elif isinstance(child_node, gast.Name):
                if child_node.id in self.name_to_var_shape:
                    shape_node = self.name_to_var_shape[child_node.id]

            if not shape_node:
                continue

            transformed = True
            parent = self.node_to_wrapper_map.get(child_node).parent.node
            for field, value in gast.iter_fields(parent):
                if value is child_node:
                    setattr(parent, field, create_convert_shape_node(shape_node))
                    break
                # Some child_node may be in a list such as gast.Compare
                if isinstance(value, list):
                    position = next(
                        (j for j, item in enumerate(value) if item is child_node),
                        None)
                    if position is not None:
                        value[position] = create_convert_shape_node(shape_node)
                        break
        return transformed
コード例 #8
0
ファイル: parser.py プロジェクト: zhaojiawen1996/tensorflow
def _without_context(node, lines, minl, maxl):
    """Returns a clean node and source code without indenting and context."""
    for n in gast.walk(node):
        lineno = getattr(n, 'lineno', None)
        if lineno is not None:
            n.lineno = lineno - minl
        end_lineno = getattr(n, 'end_lineno', None)
        if end_lineno is not None:
            n.end_lineno = end_lineno - minl

    code_lines = lines[minl - 1:maxl]

    # Attempt to clean up surrounding context code.

    end_col_offset = getattr(node, 'end_col_offset', None)
    if end_col_offset is not None:
        # This is only available in 3.8.
        code_lines[-1] = code_lines[-1][:end_col_offset]

    col_offset = getattr(node, 'col_offset', None)
    if col_offset is None:
        # Older Python: try to find the "lambda" token. This is brittle.
        match = re.search(r'(?<!\w)lambda(?!\w)', code_lines[0])
        if match is not None:
            col_offset = match.start(0)

    if col_offset is not None:
        code_lines[0] = code_lines[0][col_offset:]

    code_block = '\n'.join([c.rstrip() for c in code_lines])

    return node, code_block
コード例 #9
0
def explicit_loop_indexes(node):
  """Run ExplicitLoopIndexes on node, then strip activity annotations.

  The 'active_in', 'active_out', 'active_gen' and 'active_kill' annotations
  are deleted from every node of the transformed tree that has them.

  Returns:
    The transformed node.
  """
  node = ExplicitLoopIndexes().visit(node)
  activity_keys = ('active_in', 'active_out', 'active_gen', 'active_kill')
  for n in gast.walk(node):
    for stale_key in (k for k in activity_keys if anno.hasanno(n, k)):
      anno.delanno(n, stale_key)
  return node
コード例 #10
0
def run_analyses(node, analyses):
    """Perform dataflow analysis on all functions within an AST.

    A CFG is built for every ``gast.FunctionDef`` in ``node``. Each Backward
    analysis visits the CFG from its exit node and each Forward analysis from
    its entry node. Results are then propagated onto the AST with
    ``PropagateAnalysis``.

    Args:
      node: An AST node on which to run dataflow analysis.
      analyses: Either an instance of the Forward or Backward dataflow analysis
        class, or a list or tuple of them.

    Returns:
      node: The node, but now with annotations on the AST nodes containing the
      results of the dataflow analyses.

    Raises:
      TypeError: if any analysis is neither a Forward nor a Backward instance.
    """
    if not isinstance(analyses, (tuple, list)):
        analyses = (analyses, )
    if any(not isinstance(a, (Forward, Backward)) for a in analyses):
        raise TypeError('not a valid forward analysis object')

    function_defs = (n for n in gast.walk(node)
                     if isinstance(n, gast.FunctionDef))
    for function_def in function_defs:
        cfg_obj = CfgBuilder().build_cfg(function_def)
        for analysis in analyses:
            if isinstance(analysis, Backward):
                analysis.visit(cfg_obj.exit)
            elif isinstance(analysis, Forward):
                analysis.visit(cfg_obj.entry)

    for analysis in analyses:
        PropagateAnalysis(analysis).visit(node)
    return node
コード例 #11
0
ファイル: cfg.py プロジェクト: Eagle732/tensorflow
def run_analyses(node, analyses):
  """Perform dataflow analysis on all functions within an AST.

  A CFG is built for every ``gast.FunctionDef`` in ``node``. Each Backward
  analysis visits the CFG from its exit node and each Forward analysis from
  its entry node. Results are then propagated onto the AST with
  ``PropagateAnalysis``.

  Args:
    node: An AST node on which to run dataflow analysis.
    analyses: Either an instance of the Forward or Backward dataflow analysis
      class, or a list or tuple of them.

  Returns:
    node: The node, but now with annotations on the AST nodes containing the
    results of the dataflow analyses.

  Raises:
    TypeError: if any analysis is neither a Forward nor a Backward instance.
  """
  if not isinstance(analyses, (tuple, list)):
    analyses = (analyses,)
  if any(not isinstance(a, (Forward, Backward)) for a in analyses):
    raise TypeError('not a valid forward analysis object')

  function_defs = (n for n in gast.walk(node)
                   if isinstance(n, gast.FunctionDef))
  for function_def in function_defs:
    cfg_obj = CfgBuilder().build_cfg(function_def)
    for analysis in analyses:
      if isinstance(analysis, Backward):
        analysis.visit(cfg_obj.exit)
      elif isinstance(analysis, Forward):
        analysis.visit(cfg_obj.entry)

  for analysis in analyses:
    PropagateAnalysis(analysis).visit(node)
  return node
コード例 #12
0
    def test_compute_jumps_out_edges(self):
        """Each jump gets a JUMPS_OUT_OF edge to the construct it leaves.

        break/continue point at their innermost loop; return (and its value)
        points at the enclosing function.
        """
        tree = gast.parse(
            textwrap.dedent("""\
          def foo():        # tree.body[0]
            return          # tree.body[0].body[0]
            while True:     # tree.body[0].body[1]
              break         # tree.body[0].body[1].body[0]
              continue      # tree.body[0].body[1].body[1]
              return        # tree.body[0].body[1].body[2]
              while True:   # tree.body[0].body[1].body[3]
                break       # tree.body[0].body[1].body[3].body[0]
                return 4    # tree.body[0].body[1].body[3].body[1]
          """))

        func = tree.body[0]
        outer_loop = func.body[1]
        inner_loop = outer_loop.body[3]
        edge = graph_edge_util.JUMPS_OUT_OF_EDGE_TYPE
        expected_targets = [
            (func.body[0], func, edge),
            (outer_loop.body[0], outer_loop, edge),
            (outer_loop.body[1], outer_loop, edge),
            (outer_loop.body[2], func, edge),
            (inner_loop.body[0], inner_loop, edge),
            (inner_loop.body[1], func, edge),
            (inner_loop.body[1].value, func, edge),
        ]

        # For this test, we pretend that the AST nodes are the node ids.
        node_ids = {id(x): x for x in gast.walk(tree)}
        targets = graph_edge_util.compute_jumps_out_edges(tree, node_ids)
        self.assertCountEqual(targets, expected_targets)
コード例 #13
0
    def visit_Assign(self, node):
        """Visit every call on the right-hand side of an assignment.

        Returns None if ``self._update_class_node_dict(node)`` is truthy;
        otherwise passes each ``gast.Call`` under ``node.value`` to
        ``self._visit_Call`` and returns ``node``.
        """
        if self._update_class_node_dict(node):
            return None

        calls = (c for c in gast.walk(node.value) if isinstance(c, gast.Call))
        for call in calls:
            self._visit_Call(call)
        return node
コード例 #14
0
 def is_active(self, node):
     """Return True if ``node`` loads a name listed in its 'active_in' annotation."""
     active_in = anno.getanno(node, 'active_in')
     return any(
         isinstance(succ, gast.Name)
         and isinstance(succ.ctx, gast.Load)
         and succ.id in active_in
         for succ in gast.walk(node))
コード例 #15
0
 def test_walk(self):
     """Parsing ``x + 1`` gives the expected dump, and walking it yields 6 nodes."""
     tree = gast.parse('x + 1', mode='eval')
     expected_dump = ("Expression(body=BinOp(left=Name(id='x', ctx=Load(), "
                      "annotation=None), op=Add(), right=Num(n=1)))")
     self.assertEqual(gast.dump(tree), expected_dump)
     walked = list(gast.walk(tree))
     self.assertEqual(len(walked), 6)
コード例 #16
0
 def _visit_While(self, node):
     """Visit a ``while`` loop's condition, then every jump inside the loop.

     ``generic_visit`` is applied to ``node.test``, so the condition's
     children are visited. Every break/continue under ``node`` (including
     those in nested loops) is passed to ``_visit_break_continue``.
     """
     assert isinstance(node, gast.While)
     self.generic_visit(node.test)
     jumps = (c for c in gast.walk(node)
              if isinstance(c, (gast.Continue, gast.Break)))
     for jump in jumps:
         self._visit_break_continue(jump)
コード例 #17
0
ファイル: cfg.py プロジェクト: zouzias/tangent
def forward(node, analysis):
    """Perform a given analysis on all functions within an AST.

    For every ``gast.FunctionDef`` under ``node``, a CFG is built and
    ``analysis`` visits it from the entry node.

    Returns:
      ``node``, unchanged apart from whatever the analysis records.

    Raises:
      TypeError: if ``analysis`` is not a Forward instance.
    """
    if not isinstance(analysis, Forward):
        raise TypeError('not a valid forward analysis object')
    function_defs = (succ for succ in gast.walk(node)
                     if isinstance(succ, gast.FunctionDef))
    for function_def in function_defs:
        analysis.visit(CFG.build_cfg(function_def).entry)
    return node
コード例 #18
0
ファイル: annotations.py プロジェクト: yalechang/tangent
def clearanno(node):
    """Reset annotations throughout ``node`` to only the FIXED_ANNOTATIONS.

    Nodes that do not have the ANNOTATION_FIELD attribute are left untouched.
    On the rest, the field is replaced by a new dict holding only the fixed
    annotations that were present.

    Returns:
      ``node``.
    """
    for succ in gast.walk(node):
        if not hasattr(succ, ANNOTATION_FIELD):
            continue
        kept = {key: getanno(succ, key)
                for key in FIXED_ANNOTATIONS if hasanno(succ, key)}
        setattr(succ, ANNOTATION_FIELD, kept)
    return node
コード例 #19
0
 def visit_Expr(self, node):
     """Visit the calls inside an expression statement.

     Calls under ``node.value`` are handled in walk order. As soon as one
     satisfies ``is_dygraph_api`` the method returns None; calls seen before
     it have already been passed to ``self._visit_Call``. Otherwise every
     call is visited and ``node`` is returned.
     """
     calls = (c for c in gast.walk(node.value) if isinstance(c, gast.Call))
     for call in calls:
         if is_dygraph_api(call):
             return
         self._visit_Call(call)
     return node
コード例 #20
0
ファイル: origin_info.py プロジェクト: zjwangmin/tensorflow
def copy_origin(from_node, to_node):
    """Copies the origin info from a node to another, recursively.

    ``to_node`` may be a single node or a list/tuple of nodes. Every node in
    each target subtree receives ``from_node``'s ORIGIN annotation. Nothing
    happens if ``from_node`` has no origin.
    """
    origin = anno.Basic.ORIGIN.of(from_node, default=None)
    if origin is None:
        return
    targets = to_node if isinstance(to_node, (list, tuple)) else (to_node, )
    for target in targets:
        for n in gast.walk(target):
            anno.setanno(n, anno.Basic.ORIGIN, origin)
コード例 #21
0
 def visit_Compare(self, node):
     """Visit a comparison and record it if it raised the control-flow count.

     Comparisons for which ``compare_with_none`` is true are not visited.
     Otherwise the children are visited generically, then every Subscript
     under the node goes to ``_visit_Subscript``. The node is added to
     ``_compare_node_tenor_set`` if ``is_control_flow_num`` grew meanwhile.
     """
     control_flow_before = self.is_control_flow_num
     if not compare_with_none(node):
         self.generic_visit(node)
         subscripts = (c for c in gast.walk(node)
                       if isinstance(c, gast.Subscript))
         for subscript in subscripts:
             self._visit_Subscript(subscript)
     if self.is_control_flow_num > control_flow_before:
         self._compare_node_tenor_set.add(node)
     return node
コード例 #22
0
ファイル: origin_info.py プロジェクト: ZhangXinNan/tensorflow
def resolve(nodes, source, function=None):
  """Adds an origin information to all nodes inside the body of function.

  Args:
    nodes: A single AST node or a list/tuple of AST nodes. Node line numbers
      are taken to be 1-based relative to `source`.
    source: Text, the source code string for the function whose body nodes will
      be annotated.
    function: Callable, the function that will have all nodes inside of it
      annotation with an OriginInfo annotation with key anno.Basic.ORIGIN.  If
      it is None then only the line numbers and column offset will be set in the
      annotation, with the rest of the information being None.

  Returns:
    None. Every node that has a `lineno` is annotated in place.
  """
  if not isinstance(nodes, (list, tuple)):
    nodes = (nodes,)

  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None

  # TODO(mdan): Pull this to a separate utility.
  # Maps 1-based line numbers *within `source`* to comment text.
  code_reader = six.StringIO(source)
  comment_map = {}
  for token in tokenize.generate_tokens(code_reader.readline):
    tok_type, tok_string, loc, _, _ = token
    srow, _ = loc
    if tok_type == tokenize.COMMENT:
      comment_map[srow] = tok_string.strip()[1:].strip()

  source_lines = source.split('\n')
  for node in nodes:
    for n in gast.walk(node):
      if not hasattr(n, 'lineno'):
        continue

      lineno_in_body = n.lineno

      source_code_line = source_lines[lineno_in_body - 1]
      if function:
        # Line 1 of `source` is file line `function_lineno`, hence the -1.
        source_lineno = function_lineno + lineno_in_body - 1
        function_name = function.__name__
      else:
        source_lineno = lineno_in_body
        function_name = None

      location = Location(function_filepath, source_lineno, n.col_offset)
      # comment_map is keyed by source-relative rows, so look up with the
      # body-relative line number rather than the file-absolute one.
      origin = OriginInfo(location, function_name,
                          source_code_line, comment_map.get(lineno_in_body))
      anno.setanno(n, anno.Basic.ORIGIN, origin)
コード例 #23
0
def resolve(nodes, source, function=None):
    """Adds an origin information to all nodes inside the body of function.

    Args:
      nodes: A single AST node or a list/tuple of AST nodes. Node line numbers
        are taken to be 1-based relative to `source`.
      source: Text, the source code string for the function whose body nodes
        will be annotated.
      function: Callable, the function that will have all nodes inside of it
        annotation with an OriginInfo annotation with key anno.Basic.ORIGIN.
        If it is None then only the line numbers and column offset will be set
        in the annotation, with the rest of the information being None.

    Returns:
      None. Every node that has a `lineno` is annotated in place.
    """
    if not isinstance(nodes, (list, tuple)):
        nodes = (nodes, )

    if function:
        _, function_lineno = tf_inspect.getsourcelines(function)
        function_filepath = tf_inspect.getsourcefile(function)
    else:
        function_lineno = None
        function_filepath = None

    # TODO(mdan): Pull this to a separate utility.
    # Maps 1-based line numbers *within `source`* to comment text.
    code_reader = six.StringIO(source)
    comment_map = {}
    for token in tokenize.generate_tokens(code_reader.readline):
        tok_type, tok_string, loc, _, _ = token
        srow, _ = loc
        if tok_type == tokenize.COMMENT:
            comment_map[srow] = tok_string.strip()[1:].strip()

    source_lines = source.split('\n')
    for node in nodes:
        for n in gast.walk(node):
            if not hasattr(n, 'lineno'):
                continue

            lineno_in_body = n.lineno

            source_code_line = source_lines[lineno_in_body - 1]
            if function:
                # Line 1 of `source` is file line `function_lineno`, hence -1.
                source_lineno = function_lineno + lineno_in_body - 1
                function_name = function.__name__
            else:
                source_lineno = lineno_in_body
                function_name = None

            location = Location(function_filepath, source_lineno, n.col_offset)
            # comment_map is keyed by source-relative rows, so look up with
            # the body-relative line number rather than the absolute one.
            origin = OriginInfo(location, function_name, source_code_line,
                                comment_map.get(lineno_in_body))
            anno.setanno(n, anno.Basic.ORIGIN, origin)
コード例 #24
0
 def visit_Expr(self, node):
     """Visit the calls inside an expression statement.

     Calls under ``node.value`` are handled in walk order. As soon as one
     satisfies ``is_dygraph_api`` the method returns None; calls seen before
     it have already been passed to ``self._visit_Call``. Otherwise every
     call is visited and ``node`` is returned.
     """
     calls = (c for c in gast.walk(node.value) if isinstance(c, gast.Call))
     for call in calls:
         # TODO(liym27):
         #  Considers that a dygraph api which modifies the input or has a output.
         if is_dygraph_api(call):
             return
         self._visit_Call(call)
     return node
コード例 #25
0
 def __call__(self, codeobj):
     """Return the cached gast function definition for ``codeobj``.

     On a cache miss the whole file named by ``key[0]`` is parsed once. The
     module is cached under ``(fname, 0)`` and every function definition
     (sync or async) under ``(fname, lineno)``.

     Decorated functions are also cached under the line of their first
     decorator. On Python 3.8+ ``FunctionDef.lineno`` is the ``def`` line,
     while a code object's ``co_firstlineno`` is the first decorator line.
     NOTE(review): assumes ``get_file_info`` returns
     ``(filename, co_firstlineno)``; confirm.

     Raises:
       KeyError: if no cached entry matches ``codeobj`` after parsing.
     """
     cache = self.cache
     key = self.get_file_info(codeobj)
     result = cache.get(key)
     if result is not None:
         return result
     fname = key[0]
     cache[(fname, 0)] = mod_ast = gast.ast_to_gast(self.parse_file(fname))
     for obj in gast.walk(mod_ast):
         if isinstance(obj, (gast.FunctionDef, gast.AsyncFunctionDef)):
             cache[(fname, obj.lineno)] = obj
             if obj.decorator_list:
                 first_line = min(d.lineno for d in obj.decorator_list)
                 # Never override an entry keyed by a real def line.
                 cache.setdefault((fname, first_line), obj)
     return cache[key]
コード例 #26
0
 def test_loop_vars(self):
     """Every loop in loop_funcs[i] must report the i-th expected var names."""
     for i, func in enumerate(self.loop_funcs):
         gast_root = gast.parse(inspect.getsource(func))
         name_visitor = NameVisitor(gast_root)
         loop_nodes = (n for n in gast.walk(gast_root)
                       if isinstance(n, (gast.While, gast.For)))
         for node in loop_nodes:
             loop_var_names, create_var_names = name_visitor.get_loop_var_names(
                 node)
             self.assertEqual(loop_var_names, self.loop_var_names[i])
             self.assertEqual(create_var_names, self.create_var_names[i])
コード例 #27
0
ファイル: analyzers.py プロジェクト: bentobox-dev/bento-box
def _has_own_yield(fn_ast) -> bool:
    """Return True if fn_ast's own scope contains a `yield` or `yield from`.

    Bodies of nested functions and lambdas are skipped, because a yield there
    makes the nested callable a generator, not fn_ast. Their decorators,
    argument defaults/annotations and return annotations are still scanned,
    since those are evaluated in the enclosing scope.
    """
    todo = list(fn_ast.body)
    while todo:
        node = todo.pop()
        if isinstance(node, (gast.Yield, gast.YieldFrom)):
            return True
        if isinstance(node, (gast.FunctionDef, gast.AsyncFunctionDef)):
            todo.extend(node.decorator_list)
            todo.append(node.args)
            if node.returns is not None:
                todo.append(node.returns)
            continue
        if isinstance(node, gast.Lambda):
            todo.append(node.args)
            continue
        todo.extend(gast.iter_child_nodes(node))
    return False


def analyze_func(ast: AST) -> AST:
    """Annotate `FunctionDef` nodes in the given AST with additional information.

    Walks through the `FunctionDef` nodes in given AST and annotates
    each node with the following info:
    - `n_args`: the number of entries in `args.args` (regular positional
      parameters; `*args`, `**kwargs` and keyword-only parameters are not
      counted).
    - `docstr`: the function's docstring if present, otherwise None
    - `is_empty`: whether the body consists only of `pass` statements and/or
      string-constant expressions (docstrings).
    - `is_generator`: whether the function itself contains `yield` or
      `yield from`. Yields inside nested functions or lambdas do not count.

    Args:
        ast:
            AST to scan for and annotate `FunctionDef` in.

    Returns:
        The given AST with the `FunctionDef` annotated with additional information.
    """
    # walk through AST to find FunctionDef nodes
    fn_asts = [n for n in gast.walk(ast) if isinstance(n, FunctionDef)]
    for fn_ast in fn_asts:
        fn_ast.n_args = len(fn_ast.args.args)
        fn_ast.docstr = gast.get_docstring(fn_ast)
        # detect empty if contains pass and/or just a docstrings
        fn_ast.is_empty = True
        for node in fn_ast.body:
            if isinstance(node, Pass):
                continue
            if (
                isinstance(node, Expr)
                and isinstance(node.value, Constant)
                and isinstance(node.value.value, str)
            ):
                continue
            fn_ast.is_empty = False
            break
        # detect as generator only if its own scope yields
        fn_ast.is_generator = _has_own_yield(fn_ast)

    return ast
コード例 #28
0
ファイル: ifelse_transformer.py プロジェクト: iducn/Paddle
 def visit_Compare(self, node):
     """Visit a comparison and record it if it raised the control-flow count.

     Comparisons for which ``compare_with_none`` is true are skipped.
     Otherwise the children are visited generically, then every Subscript
     under the node goes to ``_visit_Subscript``. The node is added to
     ``_compare_node_tenor_set`` if ``is_control_flow_num`` grew meanwhile.
     """
     # Ignores child node with `if x` or `if x is None`
     # TODO(Aurelius84): `if tensor` will be supported in dygraph
     # and should be considered as is_control_flow.
     control_flow_before = self.is_control_flow_num
     if not compare_with_none(node):
         self.generic_visit(node)
         subscripts = (c for c in gast.walk(node)
                       if isinstance(c, gast.Subscript))
         for subscript in subscripts:
             self._visit_Subscript(subscript)
     if self.is_control_flow_num > control_flow_before:
         self._compare_node_tenor_set.add(node)
     return node
コード例 #29
0
ファイル: fixes.py プロジェクト: zouzias/tangent
 def prepend_uninitialized_grads(self, node):
     """Insert initialisers for adjoint variables not defined on entry.

     Only runs when ``node`` has a 'defined_in' annotation. A loaded Name
     qualifies when all of the following hold:
       * it carries an 'adjoint_var' or 'temp_adjoint_var' annotation;
       * its id is not in 'defined_in';
       * its id is not yet in ``self.added``.
     For each qualifying Name, ``self._init(use)`` is passed to
     ``self.insert_top`` once per id.

     Returns:
       ``node``.
     """
     if not anno.hasanno(node, 'defined_in'):
         return node
     loads = (succ for succ in gast.walk(node)
              if isinstance(succ, gast.Name)
              and isinstance(succ.ctx, gast.Load))
     for use in loads:
         is_adjoint = (anno.hasanno(use, 'adjoint_var')
                       or anno.hasanno(use, 'temp_adjoint_var'))
         if not is_adjoint:
             continue
         if (use.id in anno.getanno(node, 'defined_in')
                 or use.id in self.added):
             continue
         self.added.add(use.id)
         self.insert_top(self._init(use))
     return node
コード例 #30
0
def resolve(node, source, function=None):
    """Adds an origin information to node and its subnodes.

    This allows us to map the original source code line numbers to generated
    source code.

    Args:
      node: gast.AST node. Should be a gast.FunctionDef. This is the node we
          annotate with origin information.
      source: Text, the source code. Should satisfy relationship
          `node in iter_tree(gast.parse(source))`; otherwise the lineno will be
          unreliable.
      function: The original function. If it is None then only the line
          numbers and column offset will be set in the annotation, with the
          rest of the information being None.
    """
    if function:
        _, function_lineno = tf_inspect.getsourcelines(function)
        function_filepath = tf_inspect.getsourcefile(function)
    else:
        function_lineno = None
        function_filepath = None

    # TODO(mdan): Pull this to a separate utility.
    # Maps 1-based line numbers *within `source`* to comment text.
    code_reader = six.StringIO(source)
    comment_map = {}
    for token in tokenize.generate_tokens(code_reader.readline):
        tok_type, tok_string, loc, _, _ = token
        srow, _ = loc
        if tok_type == tokenize.COMMENT:
            comment_map[srow] = tok_string.strip()[1:].strip()

    source_lines = source.split('\n')
    for n in gast.walk(node):
        if not hasattr(n, 'lineno'):
            continue

        within_body_offset = n.lineno - node.lineno

        source_code_line = source_lines[n.lineno - 1]
        if function:
            source_lineno = function_lineno + within_body_offset
            function_name = function.__name__
        else:
            source_lineno = n.lineno
            function_name = None

        location = Location(function_filepath, source_lineno, n.col_offset)
        # comment_map uses the same numbering as n.lineno (rows of `source`);
        # source_lineno is file-absolute when `function` is given.
        origin = OriginInfo(location, function_name, source_code_line,
                            comment_map.get(n.lineno))
        anno.setanno(n, anno.Basic.ORIGIN, origin)
コード例 #31
0
ファイル: origin_info.py プロジェクト: perfmjs/tensorflow
def resolve(node, source, function=None):
  """Adds an origin information to node and its subnodes.

  This allows us to map the original source code line numbers to generated
  source code.

  Args:
    node: gast.AST node. Should be a gast.FunctionDef. This is the node we
        annotate with origin information.
    source: Text, the source code. Should satisfy relationship
        `node in iter_tree(gast.parse(source))`; otherwise the lineno will be
        unreliable.
    function: The original function. If it is None then only the line numbers
        and column offset will be set in the annotation, with the rest of the
        information being None.
  """
  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None

  # TODO(mdan): Pull this to a separate utility.
  # Maps 1-based line numbers *within `source`* to comment text.
  code_reader = six.StringIO(source)
  comment_map = {}
  for token in tokenize.generate_tokens(code_reader.readline):
    tok_type, tok_string, loc, _, _ = token
    srow, _ = loc
    if tok_type == tokenize.COMMENT:
      comment_map[srow] = tok_string.strip()[1:].strip()

  source_lines = source.split('\n')
  for n in gast.walk(node):
    if not hasattr(n, 'lineno'):
      continue

    within_body_offset = n.lineno - node.lineno

    source_code_line = source_lines[n.lineno - 1]
    if function:
      source_lineno = function_lineno + within_body_offset
      function_name = function.__name__
    else:
      source_lineno = n.lineno
      function_name = None

    location = Location(function_filepath, source_lineno, n.col_offset)
    # comment_map uses the same numbering as n.lineno (rows of `source`);
    # source_lineno is file-absolute when `function` is given.
    origin = OriginInfo(location, function_name,
                        source_code_line, comment_map.get(n.lineno))
    anno.setanno(n, anno.Basic.ORIGIN, origin)
コード例 #32
0
ファイル: anno.py プロジェクト: JonathanRaiman/tensorflow
def dup(node, copy_map, field_name='___pyct_anno'):
  """Recursively copies annotations in an AST tree.

  For every node under `node`, and for every source key in `copy_map` that
  the node is annotated with, the annotation value is also stored under the
  mapped destination key.

  Args:
    node: ast.AST
    copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination
        key. All annotations with the source key will be copied to identical
        annotations with the destination key.
    field_name: str
  """
  for n in gast.walk(node):
    present_keys = (k for k in copy_map if hasanno(n, k, field_name))
    for src_key in present_keys:
      setanno(n, copy_map[src_key], getanno(n, src_key, field_name),
              field_name)
コード例 #33
0
def dup(node, copy_map, field_name='___pyct_anno'):
    """Recursively copies annotations in an AST tree.

    For every node under `node`, and for every source key in `copy_map` that
    the node is annotated with, the annotation value is also stored under the
    mapped destination key.

    Args:
      node: ast.AST
      copy_map: Dict[Hashable, Hashable], maps a source anno key to a
          destination key. All annotations with the source key will be copied
          to identical annotations with the destination key.
      field_name: str
    """
    for n in gast.walk(node):
        present_keys = (k for k in copy_map if hasanno(n, k, field_name))
        for src_key in present_keys:
            setanno(n, copy_map[src_key], getanno(n, src_key, field_name),
                    field_name)
コード例 #34
0
    def _visit_For(self, node):
        """Visit a ``for ... in range(...)`` loop.

        Only loops whose iterable is a direct call to the name ``range`` are
        handled. Each argument of that call is visited, then every
        break/continue under the loop (nested loops included) is passed to
        ``_visit_break_continue``. Any other loop is ignored.
        """
        assert isinstance(node, gast.For)
        # TODO
        # self.is_control_flow_num += 1
        loop_iter = node.iter
        is_range_loop = (isinstance(loop_iter, gast.Call)
                         and isinstance(loop_iter.func, gast.Name)
                         and loop_iter.func.id == "range")
        if not is_range_loop:
            return
        for arg in loop_iter.args:
            self.visit(arg)

        jumps = (c for c in gast.walk(node)
                 if isinstance(c, (gast.Continue, gast.Break)))
        for jump in jumps:
            self._visit_break_continue(jump)
コード例 #35
0
ファイル: origin_info.py プロジェクト: StephenOman/tensorflow
def resolve(nodes, source, function=None):
  """Adds an origin information to all nodes inside the body of function.

  Args:
    nodes: A single AST node or a list/tuple of AST nodes. Node line numbers
      are taken to be 1-based relative to `source`.
    source: Text, the source code string for the function whose body nodes will
      be annotated.
    function: Callable, the function that will have all nodes inside of it
      annotation with an OriginInfo annotation with key anno.Basic.ORIGIN.  If
      it is None then only the line numbers and column offset will be set in the
      annotation, with the rest of the information being None.

  Returns:
    None. Every node that has a `lineno` is annotated in place.
  """
  if not isinstance(nodes, (list, tuple)):
    nodes = (nodes,)

  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None

  source_lines = source.split('\n')
  for node in nodes:
    for n in gast.walk(node):
      if not hasattr(n, 'lineno'):
        continue

      lineno_in_body = n.lineno

      source_code_line = source_lines[lineno_in_body - 1]
      if function:
        # Line 1 of `source` is file line `function_lineno`, hence the -1.
        source_lineno = function_lineno + lineno_in_body - 1
        function_name = function.__name__
      else:
        source_lineno = lineno_in_body
        function_name = None

      location = Location(function_filepath, source_lineno, n.col_offset)
      origin = OriginInfo(location, function_name, source_code_line)
      anno.setanno(n, anno.Basic.ORIGIN, origin)
コード例 #36
0
def resolve(node, source, function=None):
  """Adds an origin information to all nodes inside the body of function.

  Args:
    node: The AST node for the function whose body nodes will be annotated.
    source: Text, the source code string for the function whose body nodes will
      be annotated.
    function: Callable, the function that will have all nodes inside of it
      annotation with an OriginInfo annotation with key anno.Basic.ORIGIN.  If
      it is None then only the line numbers and column offset will be set in the
      annotation, with the rest of the information being None.

  Returns:
    None. Every node that has a `lineno` is annotated in place.
  """
  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None
  source_lines = source.split('\n')
  located_nodes = (child for child in gast.walk(node)
                   if hasattr(child, 'lineno'))
  for n in located_nodes:
    line_text = source_lines[n.lineno - 1]
    if function:
      # n.lineno counts from line 1 of `source`, which sits on file line
      # function_lineno, so shift by function_lineno - 1.
      absolute_lineno = function_lineno - 1 + n.lineno
      name = function.__name__
    else:
      absolute_lineno, name = n.lineno, None
    info = OriginInfo(function_filepath, name, absolute_lineno, n.col_offset,
                      line_text)
    anno.setanno(n, anno.Basic.ORIGIN, info)
コード例 #37
0
def add_filename_field(node, filename):
    """Set a ``filename`` attribute on ``node`` and every node beneath it."""
    for each in ast.walk(node):
        setattr(each, 'filename', filename)
コード例 #38
0
def contains_return(node):
  """Return True if any node in the subtree rooted at `node` is a Return.

  Nested functions are searched too, so their returns also count.
  """
  return any(isinstance(n, gast.Return) for n in gast.walk(node))