def test_resolve(self):

    def test_fn(x):
      """Docstring."""
      return x  # comment

    # Parse the sample function and attach origin info to its nodes.
    module_node, fn_source = parser.parse_entity(test_fn)
    fn_def = module_node.body[0]
    origin_info.resolve(fn_def, fn_source)

    # Each row: (node, lineno, col_offset, source line, trailing comment).
    expectations = (
        (fn_def, 1, 0, 'def test_fn(x):', None),
        (fn_def.body[0], 2, 2, '  """Docstring."""', None),
        (fn_def.body[1], 3, 2, '  return x  # comment', 'comment'),
    )
    for target, lineno, col_offset, code_line, comment in expectations:
      info = anno.getanno(target, anno.Basic.ORIGIN)
      self.assertEqual(info.loc.lineno, lineno)
      self.assertEqual(info.loc.col_offset, col_offset)
      self.assertEqual(info.source_code_line, code_line)
      if comment is None:
        self.assertIsNone(info.comment)
      else:
        self.assertEqual(info.comment, comment)
Beispiel #2
0
  def _process_variable_assignment(self, source, targets):
    """Binds assignment targets to their source expressions in the scope.

    If `source` is a call whose `func` has a 'live_val' annotation that is a
    class, `source` is additionally annotated as a constructor call
    ('is_constructor', 'type' and 'type_fqn').

    Args:
      source: The AST node of the assigned value.
      targets: Sequence of AST target nodes; several targets mean multiple
        assignment.

    Raises:
      ValueError: If a target is not a tuple, list, name or attribute.
    """
    # Special case: constructors.
    if isinstance(source, gast.Call):
      func = source.func
      if anno.hasanno(func, 'live_val'):
        func_obj = anno.getanno(func, 'live_val')
        if tf_inspect.isclass(func_obj):
          anno.setanno(source, 'is_constructor', True)
          anno.setanno(source, 'type', func_obj)
          anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
          # TODO(mdan): Raise an error if constructor has side effects.
          # We can have a whitelist of no-side-effects constructors.
          # We can also step inside the constructor and further analyze.

    # Multiple targets mean multiple assignment.
    for target in targets:
      # Tuple and list targets both mean unpacking (`a, b = x`, `[a, b] = x`).
      if isinstance(target, (gast.Tuple, gast.List)):
        for i, target_item in enumerate(target.elts):
          # Two cases here:
          #   1. Static unpacking, e.g. a, b = c, d
          #   2. Dynamic unpacking, e.g. a, b = c
          # The former case is optimized away.
          if isinstance(source, (gast.Tuple, gast.List)):
            source_item = source.elts[i]
          else:
            # NOTE(review): gast.Index normally wraps an expression node, but
            # `i` is a plain int here; confirm downstream consumers accept it.
            source_item = gast.Subscript(source, gast.Index(i), ctx=None)
          self._process_variable_assignment(source_item, (target_item,))
      elif isinstance(target, (gast.Name, gast.Attribute)):
        target_symbol = anno.getanno(target, anno.Basic.QN)
        self.scope.setval(target_symbol, source)
      else:
        # Report the offending target itself: `target_item` is only bound
        # inside the unpacking branch above.
        raise ValueError(
            'assignment target has unknown type: %s' % (target,))
Beispiel #3
0
  def visit_Call(self, node):
    """Applies element type/shape declared via `utils.set_element_type`.

    A call `set_element_type(target, type[, shape])` annotates both the call
    node and the scope definition of `target` with 'element_type' and
    'element_shape'. An omitted shape defaults to a parsed `None` expression.

    Args:
      node: The `gast.Call` node being visited.

    Returns:
      The result of `self.generic_visit(node)`.

    Raises:
      ValueError: If the marker call has fewer than two or more than three
        positional arguments, or its first argument has no qualified name.
    """
    if (anno.hasanno(node.func, 'live_val') and
        anno.getanno(node.func, 'live_val') is utils.set_element_type):
      # Symbols targeted by the "set_type" marker function are assigned the
      # data type that it specified.
      if not 2 <= len(node.args) <= 3:
        # Name the marker function actually matched above in the message.
        raise ValueError('"%s" must have either two or three parameters'
                         % utils.set_element_type)
      if len(node.args) == 2:
        target_arg, type_arg = node.args
        shape_arg = parser.parse_expression('None')
      else:
        target_arg, type_arg, shape_arg = node.args
      if not anno.hasanno(target_arg, anno.Basic.QN):
        raise ValueError('the first argument of "%s" must be a symbol' %
                         utils.set_element_type)
      # TODO(mdan): This is vulnerable to symbol renaming.
      element_type = type_arg
      element_shape = shape_arg

      target_symbol = anno.getanno(target_arg, anno.Basic.QN)
      # Find the definition of this symbol and annotate it with the given
      # data type. That in turn will cause future uses of the symbol
      # to receive the same type annotation.
      definition = self.scope.getval(target_symbol)
      anno.setanno(node, 'element_type', element_type)
      anno.setanno(node, 'element_shape', element_shape)
      anno.setanno(definition, 'element_type', element_type)
      anno.setanno(definition, 'element_shape', element_shape)
      # TODO(mdan): Should we update references between definition and here?
    return self.generic_visit(node)
  def test_if_attributes(self):

    def test_fn(a):
      if a > 0:
        a.b = -a.c
        d = 2 * a
      else:
        a.b = a.c
        d = 1
      return d

    node, _ = self._parse_and_analyze(test_fn)
    if_stmt = node.body[0].body[0]
    then_scope = anno.getanno(if_stmt, NodeAnno.BODY_SCOPE)
    else_scope = anno.getanno(if_stmt, NodeAnno.ORELSE_SCOPE)

    # Both branches read `a` and `a.c`, modify `a.b` and `d`, and create `d`.
    for branch_scope in (then_scope, else_scope):
      self.assertScopeIsRmc(branch_scope, ('a', 'a.c'), ('a.b', 'd'), ('d',))
    # The enclosing function scope.
    self.assertScopeIsRmc(
        then_scope.parent, ('a', 'a.c', 'd'), ('a.b', 'd'), ('a', 'd'))
Beispiel #5
0
    def _rename_compilable_function(self, node):
        """Redirects a call to the compiled counterpart of its target.

        Constructor calls (annotated 'is_constructor') are always renamed via
        `compiled_class_name`. For other calls, `compiled_function_name` also
        decides whether a rename happens; the owner type is taken from the
        callee's 'parent_type' annotation when present. When the live target
        is a method, `node.func.value` is prepended to the call's arguments
        before the callee is replaced.

        Args:
            node: A call node whose `func` has 'live_val' and 'fqn'
                annotations.

        Returns:
            `node`, possibly modified in place.
        """
        assert anno.hasanno(node.func, 'live_val')
        assert anno.hasanno(node.func, 'fqn')
        target_entity = anno.getanno(node.func, 'live_val')
        target_fqn = anno.getanno(node.func, 'fqn')

        if not self._should_compile(node, target_fqn):
            return node

        namer = self.ctx.namer
        if anno.hasanno(node, 'is_constructor'):
            new_name = namer.compiled_class_name(
                target_fqn, live_entity=target_entity)
            do_rename = True
        else:
            if anno.hasanno(node.func, 'parent_type'):
                owner_type = anno.getanno(node.func, 'parent_type')
            else:
                # Fallback - not reliable.
                owner_type = inspect_utils.getmethodclass(target_entity)
            new_name, do_rename = namer.compiled_function_name(
                target_fqn, live_entity=target_entity, owner_type=owner_type)

        if not do_rename:
            return node

        # The renaming process turns a method into a regular function, so the
        # receiver becomes an explicit first argument.
        # TODO(mdan): Is this complete? How does it work with nested members?
        if target_entity is not None and tf_inspect.ismethod(target_entity):
            node.args = [node.func.value] + node.args
        node.func = templates.replace('func_name', func_name=new_name)[0]
        return node
Beispiel #6
0
 def visit_Call(self, node):
   """Applies the element type declared via the context's annotation hook.

   A call to `self.context.type_annotation_func(target, type)` annotates both
   the call node and the scope definition of `target` with 'element_type',
   taken from the 'live_val' annotation of the second argument.

   Args:
     node: The `gast.Call` node being visited.

   Returns:
     The result of `self.generic_visit(node)`.

   Raises:
     ValueError: If the call does not have exactly two arguments, its first
       argument has no qualified name, or its second argument has no
       'live_val' annotation.
   """
   if anno.hasanno(node.func, 'live_val'):
     # Symbols targeted by the "set_type" marker function are assigned the data
     # type that it specified.
     if (anno.getanno(node.func, 'live_val') is
         self.context.type_annotation_func):
       # Expecting the actual type to be the second argument.
       if len(node.args) != 2:
         raise ValueError('"%s" must have exactly two parameters'
                          % self.context.type_annotation_func)
       if not anno.hasanno(node.args[0], anno.Basic.QN):
         raise ValueError('the first argument of "%s" must be a symbol'
                          % self.context.type_annotation_func)
       if not anno.hasanno(node.args[1], 'live_val'):
         raise ValueError(
             'the second argument of "%s" must be statically resolvable' %
             self.context.type_annotation_func)
       target_symbol = anno.getanno(node.args[0], anno.Basic.QN)
       element_type = anno.getanno(node.args[1], 'live_val')
       # Find the definition of this symbol and annotate it with the given
       # data type. That in turn will cause future uses of the symbol
       # to receive the same type annotation.
       definition = self.scope.getval(target_symbol)
       anno.setanno(node, 'element_type', element_type)
       anno.setanno(definition, 'element_type', element_type)
       # TODO(mdan): Should we update references between definition and here?
   return self.generic_visit(node)
  def test_nested_function(self):

    def test_fn(a):

      def f(x):
        y = x * x
        return y

      b = a
      for i in a:
        c = b
        b -= f(i)
      return b, c

    node, _ = self._parse_and_analyze(test_fn)
    inner_def = node.body[0].body[0]
    inner_scope = anno.getanno(inner_def, NodeAnno.BODY_SCOPE)

    # Scope of `test_fn` itself, which encloses `f`.
    self.assertScopeIsRmc(
        inner_scope.parent,
        ('b', 'i', 'f', 'c', 'a'),
        ('f', 'b', 'c', 'i'),
        ('f', 'a', 'b', 'c', 'i'),
    )
    # Scope of the nested function `f`.
    self.assertScopeIsRmc(inner_scope, ('x', 'y'), ('y',), ('x', 'y'))
  def test_call_args_subscripts(self):

    def foo(*_):
      pass

    def test_fn(a):
      b = 1
      c = 2
      foo(a[0], a[b])
      return a[c]

    node, _ = self._parse_and_analyze(test_fn)
    # The `foo(...)` expression statement is the third statement of test_fn.
    call = node.body[0].body[2].value
    args_scope = anno.getanno(call, NodeAnno.ARGS_SCOPE)

    # The call arguments neither modify nor create symbols.
    self.assertScopeIsRmc(args_scope, ('a', 'a[0]', 'a[b]', 'b'), (), ())
    # The enclosing function scope.
    self.assertScopeIsRmc(
        args_scope.parent,
        ('a', 'a[0]', 'a[b]', 'a[c]', 'b', 'c', 'foo'),
        ('b', 'c'),
        ('a', 'b', 'c'),
    )
Beispiel #9
0
  def visit_Call(self, node):
    """Marks graph-ready arguments and rewrites calls to resolved functions.

    Calls to one of `self.nocompile_decorators` mark their first argument as
    'graph_ready'. After visiting children, a call whose callee has a
    'live_val' annotation is renamed to its compiled version when
    `_function_is_compilable` accepts the callee, or wrapped via
    `_wrap_to_py_func_single_return` when the callee's 'fqn' is in
    KNOWN_NUMPY_FUNCTIONS. Calls without 'live_val' go through
    `_insert_dynamic_conversion` in recursive mode and are left as-is
    otherwise.

    Args:
      node: The `gast.Call` node being visited.

    Returns:
      The (possibly replaced) call node.

    Raises:
      ValueError: If a marker decorator is called without arguments.
      NotImplementedError: If a resolved callee is neither compilable nor a
        known NumPy function.
    """
    # If the function is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if target_entity in self.nocompile_decorators:
        if not node.args:
          # The message has a %s placeholder, so it must actually be formatted.
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least an argument.' % (target_entity,))
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if anno.hasanno(node.func, 'fqn'):
        target_fqn = anno.getanno(node.func, 'fqn')
      else:
        target_fqn = None
      if self._function_is_compilable(target_entity):
        node = self._rename_compilable_function(node)
      elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
        # TODO(mdan): Should we replace these with equivalent TF ops instead?
        node = self._wrap_to_py_func_single_return(
            node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
      else:
        raise NotImplementedError(
            'py_func with return values (unknown function)')
    elif self.context.recursive:
      node = self._insert_dynamic_conversion(node)
    # Otherwise: unresolved functions are allowed in non-recursive mode.
    return node
Beispiel #10
0
  def test_call_with_composite_names(self):

    def foo(*_):
      pass

    def test_fn(a):
      foo(a.b, a.c)
      if a > 0:
        a.b = 2
      else:
        d = 2
        d.e = a.c
        f = d.e + 1
        a.c = f

    node = self._parse_and_analyze(test_fn)
    fn_body = node.body[0].body
    call_node = fn_body[0].value
    if_node = fn_body[1]

    # Call arguments only read the composite names.
    self.assertScopeIsRmc(
        anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
        ('a', 'a.b', 'a.c'), (), ())
    self.assertScopeIsRmc(
        anno.getanno(if_node, NodeAnno.BODY_SCOPE),
        ('a',), ('a.b',), ())
    self.assertScopeIsRmc(
        anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
        ('a', 'a.c', 'd', 'd.e', 'f'),
        ('a.c', 'd', 'd.e', 'f'),
        ('d', 'f'))
  def test_if(self):

    def test_fn(x):
      if x > 0:
        x = -x
        y = 2 * x
        z = -y
      else:
        x = 2 * x
        y = -x
        u = -y
      return z, u

    node, _ = self._parse_and_analyze(test_fn)
    if_node = node.body[0].body[0]
    then_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
    else_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)
    # Both branch scopes share the same enclosing function scope expectation.
    fn_scope_rmc = (('x', 'z', 'u'), ('x', 'y', 'z', 'u'),
                    ('x', 'y', 'z', 'u'))

    self.assertScopeIsRmc(then_scope, ('x', 'y'), ('x', 'y', 'z'), ('y', 'z'))
    # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
    self.assertScopeIsRmc(then_scope.parent, *fn_scope_rmc)
    self.assertScopeIsRmc(else_scope, ('x', 'y'), ('x', 'y', 'u'), ('y', 'u'))
    self.assertScopeIsRmc(else_scope.parent, *fn_scope_rmc)
Beispiel #12
0
  def visit(self, node):
    """Depth-first walking the CFG, applying dataflow info propagation.

    The incoming facts of `node` are the `transfer_fn`-combination of the
    outgoing facts of its already-annotated predecessors (an empty frozenset
    if there are none). Its outgoing facts are `(incoming - kill) | gen`,
    with gen/kill from `get_gen_kill`. Successors are revisited only when the
    outgoing facts changed.

    Args:
      node: CFG node whose `value` is the AST node to annotate. A falsy
        `value` marks the exit node, which is not annotated.
    """
    # node.value is None only for the exit CfgNode.
    if not node.value:
      return

    if anno.hasanno(node.value, self.out_label):
      before = anno.getanno(node.value, self.out_label)
    else:
      before = None
    preds = [
        anno.getanno(pred.value, self.out_label)
        for pred in node.prev
        if anno.hasanno(pred.value, self.out_label)
    ]
    if preds:
      incoming = functools.reduce(self.transfer_fn, preds[1:], preds[0])
    else:
      incoming = frozenset()
    anno.setanno(node.value, self.in_label, incoming)
    gen, kill = self.get_gen_kill(node, incoming)
    anno.setanno(node.value, self.gen_label, gen)
    anno.setanno(node.value, self.kill_label, kill)
    after = (incoming - kill) | gen
    anno.setanno(node.value, self.out_label, after)

    # Compare the fact sets themselves, not their hashes: equal hashes do not
    # imply equal sets, so a collision would end propagation too early.
    if after != before:
      for succ in node.next:
        self.visit(succ)
 def visit_For(self, node):
   """Visits the loop target, then processes the body and `else` blocks.

   Each block is passed to `_process_block` together with its own scope
   annotation (NodeAnno.BODY_SCOPE or NodeAnno.ORELSE_SCOPE).
   """
   node.target = self.visit(node.target)
   body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
   node.body = self._process_block(body_scope, node.body)
   orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
   node.orelse = self._process_block(orelse_scope, node.orelse)
   return node
Beispiel #14
0
def _build_source_map(node, code):
  """Maps line numbers of generated code to the origin info of user code.

  The generated `code` is re-parsed so that its nodes carry their final line
  numbers. The original AST and the re-parsed AST are then walked in parallel.
  For every pair of corresponding nodes that both carry an origin annotation,
  the generated node's line number is mapped to the original node's origin
  annotation.

  Args:
    node: An AST node of the original generated code, before the source code is
      generated.
    code: The string representation of the source code for the newly generated
      code.

  Returns:
    Dict mapping a line number in `code` (read from the `line_number` attribute
    of the re-parsed node's origin annotation) to the origin annotation of the
    corresponding node in `node`. When several nodes share a line, the one
    visited last wins.
  """
  # After we have the final generated code we reparse it to get the final line
  # numbers. Then we walk through the generated and original ASTs in parallel
  # to build the mapping between the user and generated code.
  new_node = parser.parse_str(code)
  origin_info.resolve(new_node, code)
  source_mapping = {}
  for before, after in ast_util.parallel_walk(node, new_node):
    # Need both checks because if origin information is ever copied over to new
    # nodes then we need to rely on the fact that only the original user code
    # has the origin annotation.
    if (anno.hasanno(before, anno.Basic.ORIGIN) and
        anno.hasanno(after, anno.Basic.ORIGIN)):
      source_info = anno.getanno(before, anno.Basic.ORIGIN)
      # NOTE(review): other code reads origin positions as `origin.loc.lineno`;
      # confirm the origin info type used here exposes `line_number`.
      new_line_number = anno.getanno(after, anno.Basic.ORIGIN).line_number
      source_mapping[new_line_number] = source_info
  return source_mapping
Beispiel #15
0
 def visit_Attribute(self, node):
   """Resolves the static value of an attribute access when possible.

   After the children are visited:
   - If `node.value` has a 'live_val' annotation, the attribute is looked up
     on that object. `node` then gets 'parent_type', 'live_val' and 'fqn'.
   - Otherwise, if `node.value` has a 'type' annotation that has the
     attribute, the same annotations are derived from that type and its
     'type_fqn'.
   - Otherwise, if `node.value` is a plain name, it is asserted to be local.

   Raises:
     AttributeError: If a live parent object lacks the attribute.
   """
   self.generic_visit(node)
   value = node.value
   if anno.hasanno(value, 'live_val'):
     assert anno.hasanno(value, 'fqn')
     parent_object = anno.getanno(value, 'live_val')
     if not hasattr(parent_object, node.attr):
       raise AttributeError('%s has no attribute %s' % (parent_object,
                                                        node.attr))
     anno.setanno(node, 'parent_type', type(parent_object))
     anno.setanno(node, 'live_val', getattr(parent_object, node.attr))
     anno.setanno(node, 'fqn', anno.getanno(value, 'fqn') + (node.attr,))
   # TODO(mdan): Investigate the role built-in annotations can play here.
   elif anno.hasanno(value, 'type'):
     parent_type = anno.getanno(value, 'type')
     # Holds for static members such as methods, but not for dynamic members
     # such as function attributes. In the dynamic case the node is left
     # unannotated for downstream consumers to handle.
     if hasattr(parent_type, node.attr):
       anno.setanno(node, 'parent_type', parent_type)
       anno.setanno(node, 'live_val', getattr(parent_type, node.attr))
       anno.setanno(node, 'fqn',
                    anno.getanno(value, 'type_fqn') + (node.attr,))
   elif isinstance(value, gast.Name):
     # All nonlocal symbols should be fully resolved.
     # TODO(mdan): Figure out what to do when calling attribute on local object
     # Maybe just leave as-is?
     assert anno.hasanno(value, NodeAnno.IS_LOCAL), value
   return node
Beispiel #16
0
  def visit(self, cfg_node):
    """Propagates dataflow facts backward from `cfg_node`, depth-first.

    The node's outgoing facts are the `transfer_fn`-combination of the
    incoming facts of its already-annotated successors (an empty frozenset if
    there are none). Its incoming facts are `(outgoing - kill) | gen`, with
    gen/kill from `get_gen_kill`. Predecessors are revisited only when the
    incoming facts changed.

    Args:
      cfg_node: CFG node whose `value` is the AST node to annotate. A falsy
        `value` marks the exit node, which is not annotated; the walk just
        continues to its predecessors.
    """
    # cfg_node.value is None for the exit node, which will be visited only once
    if not cfg_node.value:
      for pred in cfg_node.prev:
        self.visit(pred)
      return

    if anno.hasanno(cfg_node.value, self.in_label):
      before = anno.getanno(cfg_node.value, self.in_label)
    else:
      before = None
    succs = [
        anno.getanno(succ.value, self.in_label)
        for succ in cfg_node.next
        if anno.hasanno(succ.value, self.in_label)
    ]
    if succs:
      incoming = functools.reduce(self.transfer_fn, succs[1:], succs[0])
    else:
      incoming = frozenset()
    anno.setanno(cfg_node.value, self.out_label, incoming)
    gen, kill = self.get_gen_kill(cfg_node, incoming)
    anno.setanno(cfg_node.value, self.gen_label, gen)
    anno.setanno(cfg_node.value, self.kill_label, kill)
    after = (incoming - kill) | gen
    anno.setanno(cfg_node.value, self.in_label, after)

    # Compare the fact sets themselves, not their hashes: equal hashes do not
    # imply equal sets, so a collision would end propagation too early.
    if after != before:
      for pred in cfg_node.prev:
        self.visit(pred)
Beispiel #17
0
  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Builds an `ag__.list_pop` statement replacing `target.pop([element])`.

    The element type and shape are read from the 'element_type' and
    'element_shape' annotations of `target`, defaulting to a `None`
    expression when absent. A missing pop argument becomes a parsed `None`.

    Args:
      original_call_node: The `target.pop(...)` call node; its `func` must be
        a `gast.Attribute`.
      pop_var_name: Name to which the popped value is assigned.

    Returns:
      The result of `templates.replace` on the list_pop template.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    call_args = original_call_node.args
    pop_element = (call_args[0] if call_args
                   else parser.parse_expression('None'))

    # The call will be something like "target.pop()", and the dtype is hooked
    # to target, hence the func.value.
    pop_target = original_call_node.func.value
    dtype = anno.getanno(
        pop_target, 'element_type',
        default=templates.replace_as_expression('None'))
    shape = anno.getanno(
        pop_target, 'element_shape',
        default=templates.replace_as_expression('None'))

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=pop_target,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=dtype,
        shape=shape)
Beispiel #18
0
  def _rename_compilable_function(self, node):
    """Redirects a call to the compiled counterpart of its target.

    Constructor calls (annotated 'is_constructor') are always renamed via
    `compiled_class_name`. For other calls, `compiled_function_name` also
    decides whether a rename happens; the owner type is taken from the
    callee's 'parent_type' annotation when present. When the live target is a
    method, `node.func.value` is prepended to the call's arguments before the
    callee is replaced.

    Args:
      node: A call node whose `func` has 'live_val' and 'fqn' annotations.

    Returns:
      `node`, possibly modified in place.
    """
    assert anno.hasanno(node.func, 'live_val')
    assert anno.hasanno(node.func, 'fqn')
    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = anno.getanno(node.func, 'fqn')

    if not self._should_compile(node, target_fqn):
      return node

    if anno.hasanno(node, 'is_constructor'):
      do_rename = True
      new_name = self.ctx.namer.compiled_class_name(
          target_fqn, live_entity=target_entity)
    elif anno.hasanno(node.func, 'parent_type'):
      new_name, do_rename = self.ctx.namer.compiled_function_name(
          target_fqn, live_entity=target_entity,
          owner_type=anno.getanno(node.func, 'parent_type'))
    else:
      # Fallback - not reliable.
      new_name, do_rename = self.ctx.namer.compiled_function_name(
          target_fqn, live_entity=target_entity,
          owner_type=inspect_utils.getmethodclass(target_entity))

    if do_rename:
      is_method = (target_entity is not None and
                   tf_inspect.ismethod(target_entity))
      if is_method:
        # The renaming process will transform it into a regular function.
        # TODO(mdan): Is this complete? How does it work with nested members?
        node.args = [node.func.value] + node.args
      node.func = templates.replace('func_name', func_name=new_name)[0]
    return node
Beispiel #19
0
  def test_reaching(self):

    def f(x):
      print(x)
      while True:
        x = x
        x = x
      return x

    node, ctx = self._parse_and_analyze(f, {})
    cfg.run_analyses(node, cfg.ReachingDefinitions(ctx))
    fn_body = node.body[0].body
    loop_body = fn_body[1].body

    def source_types(stmt):
      # Node types of the definitions reaching `stmt`.
      return {type(d[1]) for d in anno.getanno(stmt, 'definitions_in')}

    # Only the argument reaches the first statement.
    self.assertEqual(source_types(fn_body[0]), {gast.arguments})
    # Loop entry: one definition, two possible sources - the arguments (loop
    # not yet entered) or an assignment (previous iteration).
    self.assertEqual(
        source_types(loop_body[0]), {gast.arguments, gast.Assign})
    # Second loop statement: only the assignment on the previous line.
    self.assertEqual(source_types(loop_body[1]), {gast.Assign})
    # After the loop: same situation as the loop entry.
    self.assertEqual(
        source_types(fn_body[2]), {gast.arguments, gast.Assign})
  def test_if_subscripts(self):

    def test_fn(a, b, c, e):
      if a > 0:
        a[b] = -a[c]
        d = 2 * a
      else:
        a[0] = e
        d = 1
      return d

    node, _ = self._parse_and_analyze(test_fn)
    if_node = node.body[0].body[0]
    then_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
    else_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)

    self.assertScopeIsRmc(
        then_scope, ('a', 'b', 'c', 'a[c]'), ('a', 'a[b]', 'd'), ('d',))
    # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"?
    self.assertScopeIsRmc(else_scope, ('a', 'e'), ('a', 'a[0]', 'd'), ('d',))
    # The enclosing function scope.
    self.assertScopeIsRmc(
        else_scope.parent,
        ('a', 'b', 'c', 'd', 'e', 'a[c]'),
        ('a', 'd', 'a[b]', 'a[0]'),
        ('a', 'b', 'c', 'd', 'e'),
    )
Beispiel #21
0
 def visit_With(self, node):
     """Summarizes the dataflow facts of a `with` statement on the node itself.

     Incoming facts start from those of the first body statement and are
     combined (via `|=`) with those of every `with` item. Outgoing facts are
     those of the last body statement.
     """
     self.generic_visit(node)
     first_stmt = node.body[0]
     last_stmt = node.body[-1]
     in_facts = anno.getanno(first_stmt, self.in_label)
     for with_item in node.items:
         in_facts |= anno.getanno(with_item, self.in_label)
     out_facts = anno.getanno(last_stmt, self.out_label)
     anno.setanno(node, self.in_label, in_facts)
     anno.setanno(node, self.out_label, out_facts)
Beispiel #22
0
 def visit_With(self, node):
   """Summarizes the dataflow facts of a `with` statement on the node itself.

   Incoming facts start from those of the first body statement and are
   combined (via `|=`) with those of every `with` item. Outgoing facts are
   those of the last body statement.
   """
   self.generic_visit(node)
   merged_in = anno.getanno(node.body[0], self.in_label)
   for item_node in node.items:
     merged_in |= anno.getanno(item_node, self.in_label)
   last_out = anno.getanno(node.body[-1], self.out_label)
   anno.setanno(node, self.in_label, merged_in)
   anno.setanno(node, self.out_label, last_out)
Beispiel #23
0
    def test_entity_scope_tracking(self):

        class TestTransformer(transformer.Base):

            # Assign is an arbitrary choice of node to annotate; it is just
            # easy to find in the tree.
            def visit_Assign(self, node):
                anno.setanno(node, 'enclosing_entities',
                             self.enclosing_entities)
                return self.generic_visit(node)

            # BinOp shows up inside the lambda function.
            def visit_BinOp(self, node):
                anno.setanno(node, 'enclosing_entities',
                             self.enclosing_entities)
                return self.generic_visit(node)

        tr = TestTransformer(self._simple_source_info())

        def test_function():
            a = 0

            class TestClass(object):
                def test_method(self):
                    b = 0

                    def inner_function(x):
                        c = 0
                        d = lambda y: (x + y)
                        return c, d

                    return b, inner_function

            return a, TestClass

        node, _ = parser.parse_entity(test_function)
        node = tr.visit(node)

        fn_node = node.body[0]
        class_node = fn_node.body[1]
        method_node = class_node.body[0]
        inner_node = method_node.body[1]
        lambda_node = inner_node.body[1].value

        def entities_of(n):
            return anno.getanno(n, 'enclosing_entities')

        # `a = 0`
        self.assertEqual((fn_node,), entities_of(fn_node.body[0]))
        # `b = 0`
        self.assertEqual((fn_node, class_node, method_node),
                         entities_of(method_node.body[0]))
        # `c = 0`
        self.assertEqual((fn_node, class_node, method_node, inner_node),
                         entities_of(inner_node.body[0]))
        # `x + y` inside the lambda.
        self.assertEqual(
            (fn_node, class_node, method_node, inner_node, lambda_node),
            entities_of(lambda_node.body))
Beispiel #24
0
  def visit_For(self, node):
    """Rewrites a `for` loop into a call to `ag__.for_stmt`.

    Symbols modified but not created in the loop body form the loop state.
    Inside the generated `extra_test` and `loop_body` functions, each state
    symbol is renamed to a fresh name derived from its `ssf()` form, whenever
    that name differs from the symbol. The extra test comes from the node's
    'extra_test' annotation and defaults to `True`.

    Args:
      node: The `gast.For` node being visited.

    Returns:
      The result of `templates.replace` on the for_stmt template.
    """
    self.generic_visit(node)

    self._validate_no_live_vars_created(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    loop_state = list(body_scope.modified - body_scope.created)
    referenced = body_scope.referenced

    fresh_names = [
        self.ctx.namer.new_symbol(sym.ssf(), referenced) for sym in loop_state
    ]
    renames = {}
    for sym, fresh in zip(loop_state, fresh_names):
      if str(sym) != fresh:
        renames[sym] = fresh

    if len(loop_state) == 1:
      state_arg = loop_state[0]
      state_ssf_arg = fresh_names[0]
      state_ast_tuple = state_arg
    else:
      state_arg = loop_state
      state_ssf_arg = fresh_names
      state_ast_tuple = gast.Tuple([sym.ast() for sym in loop_state], None)

    renamed_body = ast_util.rename_symbols(node.body, renames)
    if anno.hasanno(node, 'extra_test'):
      extra_test = ast_util.rename_symbols(
          anno.getanno(node, 'extra_test'), renames)
    else:
      extra_test = parser.parse_expression('True')

    extra_test_name = self.ctx.namer.new_symbol('extra_test', referenced)
    body_name = self.ctx.namer.new_symbol('loop_body', referenced)

    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    return templates.replace(
        template,
        state=state_arg,
        state_ssf=state_ssf_arg,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=extra_test_name,
        extra_test_expr=extra_test,
        body_name=body_name,
        body=renamed_body)
Beispiel #25
0
    def visit_For(self, node):
        """Rewrites a `for` loop into a call to `ag__.for_stmt`.

        Symbols modified but not created in the loop body form the loop
        state. Inside the generated `extra_test` and `loop_body` functions,
        each state symbol is renamed to a fresh name derived from its `ssf()`
        form, whenever that name differs from the symbol. The extra test comes
        from the node's 'extra_test' annotation and defaults to `True`.

        Args:
            node: The `gast.For` node being visited.

        Returns:
            The result of `templates.replace` on the for_stmt template.
        """
        self.generic_visit(node)

        self._validate_no_live_vars_created(node)

        scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        carried = list(scope.modified - scope.created)
        seen = scope.referenced

        carried_ssf = [
            self.ctx.namer.new_symbol(sym.ssf(), seen) for sym in carried
        ]
        rename_map = {}
        for sym, ssf_name in zip(carried, carried_ssf):
            if str(sym) != ssf_name:
                rename_map[sym] = ssf_name

        if len(carried) == 1:
            state_value = carried[0]
            ssf_value = carried_ssf[0]
            state_ast_tuple = state_value
        else:
            state_value = carried
            ssf_value = carried_ssf
            state_ast_tuple = gast.Tuple([sym.ast() for sym in carried], None)

        new_body = ast_util.rename_symbols(node.body, rename_map)
        if anno.hasanno(node, 'extra_test'):
            extra_test = ast_util.rename_symbols(
                anno.getanno(node, 'extra_test'), rename_map)
        else:
            extra_test = parser.parse_expression('True')

        extra_test_name = self.ctx.namer.new_symbol('extra_test', seen)
        body_name = self.ctx.namer.new_symbol('loop_body', seen)

        template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
        return templates.replace(
            template,
            state=state_value,
            state_ssf=ssf_value,
            state_ast_tuple=state_ast_tuple,
            iter_=node.iter,
            iterate=node.target,
            extra_test_name=extra_test_name,
            extra_test_expr=extra_test,
            body_name=body_name,
            body=new_body)
    def test_rename_symbols_annotations(self):
        # Annotations on a node must survive symbol renaming unchanged.
        tree = qual_names.resolve(parser.parse_str('a[i]'))
        anno.setanno(tree, 'foo', 'bar')
        original_value = anno.getanno(tree, 'foo')

        renamed = ast_util.rename_symbols(
            tree, {qual_names.QN('a'): qual_names.QN('b')})

        self.assertIs(anno.getanno(renamed, 'foo'), original_value)
Beispiel #27
0
 def visit_While(self, node):
     """Summarizes the dataflow facts of a `while` loop on the node itself.

     Incoming facts are those of the first body statement combined (via
     `|=`) with those of the loop test. Outgoing facts are those of the last
     body statement, merged via `transfer_fn` with those of the last `else`
     statement when an `else` block exists.
     """
     self.generic_visit(node)
     in_facts = anno.getanno(node.body[0], self.in_label)
     in_facts |= anno.getanno(node.test, self.in_label)
     out_facts = anno.getanno(node.body[-1], self.out_label)
     if node.orelse:
         out_facts = self.transfer_fn(
             out_facts, anno.getanno(node.orelse[-1], self.out_label))
     anno.setanno(node, self.in_label, in_facts)
     anno.setanno(node, self.out_label, out_facts)
Beispiel #28
0
 def visit_While(self, node):
   """Summarizes the dataflow facts of a `while` loop on the node itself.

   Incoming facts are those of the first body statement combined (via `|=`)
   with those of the loop test. Outgoing facts are those of the last body
   statement, merged via `transfer_fn` with those of the last `else`
   statement when an `else` block exists.
   """
   self.generic_visit(node)
   loop_in = anno.getanno(node.body[0], self.in_label)
   loop_in |= anno.getanno(node.test, self.in_label)
   loop_out = anno.getanno(node.body[-1], self.out_label)
   if node.orelse:
     else_out = anno.getanno(node.orelse[-1], self.out_label)
     loop_out = self.transfer_fn(loop_out, else_out)
   anno.setanno(node, self.in_label, loop_in)
   anno.setanno(node, self.out_label, loop_out)
Beispiel #29
0
  def test_attribute_names(self):
    """The callee `constant_op.constant` resolves to its live value and fqn."""

    def test_fn():
      return constant_op.constant(0)

    node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
    func_node = node.body[0].body[0].value.func
    # `assertEquals` is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
    self.assertEqual((constant_op.__name__, 'constant'),
                     anno.getanno(func_node, 'fqn'))
Beispiel #30
0
  def test_rename_symbols_annotations(self):
    # Renaming symbols must keep the node's annotations intact.
    resolved = qual_names.resolve(parser.parse_str('a[i]'))
    anno.setanno(resolved, 'foo', 'bar')
    before = anno.getanno(resolved, 'foo')

    rename_map = {qual_names.QN('a'): qual_names.QN('b')}
    renamed = ast_util.rename_symbols(resolved, rename_map)

    self.assertIs(anno.getanno(renamed, 'foo'), before)
  def test_attribute_names(self):
    """The callee `constant_op.constant` resolves to its live value and fqn."""

    def test_fn():
      return constant_op.constant(0)

    node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
    func_node = node.body[0].body[0].value.func
    # `assertEquals` is a deprecated alias that was removed in Python 3.12.
    self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
    self.assertEqual((constant_op.__name__, 'constant'),
                     anno.getanno(func_node, 'fqn'))
Beispiel #32
0
 def visit_For(self, node):
   """Summarizes the dataflow facts of a `for` loop on the node itself.

   Incoming facts are those of the first body statement, without the loop
   target's qualified name, stored as a frozenset. Outgoing facts are those
   of the last body statement, merged via `transfer_fn` with those of the
   last `else` statement when an `else` block exists.
   """
   self.generic_visit(node)
   body_in = set(anno.getanno(node.body[0], self.in_label))
   body_in.discard(anno.getanno(node.target, anno.Basic.QN))
   last_out = anno.getanno(node.body[-1], self.out_label)
   if node.orelse:
     last_out = self.transfer_fn(
         last_out, anno.getanno(node.orelse[-1], self.out_label))
   anno.setanno(node, self.in_label, frozenset(body_in))
   anno.setanno(node, self.out_label, last_out)
Beispiel #33
0
 def visit_For(self, node):
     """Summarizes the dataflow facts of a `for` loop on the node itself.

     Incoming facts are those of the first body statement, without the loop
     target's qualified name, stored as a frozenset. Outgoing facts are those
     of the last body statement, merged via `transfer_fn` with those of the
     last `else` statement when an `else` block exists.
     """
     self.generic_visit(node)
     entry_facts = set(anno.getanno(node.body[0], self.in_label))
     loop_var = anno.getanno(node.target, anno.Basic.QN)
     entry_facts.discard(loop_var)
     exit_facts = anno.getanno(node.body[-1], self.out_label)
     if node.orelse:
         else_exit = anno.getanno(node.orelse[-1], self.out_label)
         exit_facts = self.transfer_fn(exit_facts, else_exit)
     anno.setanno(node, self.in_label, frozenset(entry_facts))
     anno.setanno(node, self.out_label, exit_facts)
Beispiel #34
0
  def visit_For(self, node):
    """Replaces a for statement with a call to `__ops.for_loop`.

    The symbols modified (but not created) in the loop body become the loop
    state. They are renamed to fresh names inside two generated functions: an
    extra-condition function and a loop-body function. Both are passed to
    `__ops.for_loop` together with the iterated expression and the initial
    state.

    Args:
      node: the For node; must carry a NodeAnno.BODY_SCOPE annotation and may
        carry an 'extra_cond' annotation.

    Returns:
      The replacement produced by templates.replace.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    # Loop state: symbols written in the body that already existed before it.
    body_closure = body_scope.modified - body_scope.created
    all_referenced = body_scope.referenced

    # NOTE(review): iteration order of a set; the generated state order (and
    # symbol names) may therefore vary between runs — confirm this is fine.
    state = list(body_closure)

    # Fresh names for each state symbol, avoiding every referenced name.
    state_ssf = [
        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Only symbols whose fresh name differs from their original need renaming.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # A single state symbol is used directly instead of as a 1-tuple.
    # NOTE(review): an empty state produces an empty gast.Tuple target —
    # confirm that is intended.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    if anno.hasanno(node, 'extra_cond'):
      extra_cond = anno.getanno(node, 'extra_cond')
      extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
    else:
      # No extra condition: the generated condition function returns True.
      extra_cond = parser.parse_expression('True')

    template = """
      def extra_cond_name(state_ssf):
        return extra_cond_expr
      def body_name(iterate, state_ssf):
        body
        return state_ssf,
      state_ast_tuple = __ops.for_loop(
          iterated, extra_cond_name, body_name, (state,))
    """
    # Keyword names below match the placeholders in the template.
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iterated=node.iter,
        iterate=node.target,
        extra_cond_name=self.context.namer.new_symbol('extra_cond',
                                                      all_referenced),
        extra_cond_expr=extra_cond,
        body_name=self.context.namer.new_symbol('loop_body', all_referenced),
        body=node_body)

    return node
Beispiel #35
0
  def visit_For(self, node):
    """Replaces a for statement with a call to `ag__.for_loop`.

    The symbols modified (but not created) in the loop body become the loop
    state. They are renamed to fresh names inside two generated functions: an
    extra-condition function and a loop-body function. Both are passed to
    `ag__.for_loop` together with the iterated expression and the initial
    state.

    Args:
      node: the For node; must carry a NodeAnno.BODY_SCOPE annotation and may
        carry an 'extra_cond' annotation.

    Returns:
      The replacement produced by templates.replace.
    """
    self.generic_visit(node)

    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    # Loop state: symbols written in the body that already existed before it.
    body_closure = body_scope.modified - body_scope.created
    all_referenced = body_scope.referenced

    # NOTE(review): iteration order of a set; the generated state order (and
    # symbol names) may therefore vary between runs — confirm this is fine.
    state = list(body_closure)

    # Fresh names for each state symbol, avoiding every referenced name.
    state_ssf = [
        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Only symbols whose fresh name differs from their original need renaming.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # A single state symbol is used directly instead of as a 1-tuple.
    # NOTE(review): an empty state produces an empty gast.Tuple target —
    # confirm that is intended.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    if anno.hasanno(node, 'extra_cond'):
      extra_cond = anno.getanno(node, 'extra_cond')
      extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
    else:
      # No extra condition: the generated condition function returns True.
      extra_cond = parser.parse_expression('True')

    template = """
      def extra_cond_name(state_ssf):
        return extra_cond_expr
      def body_name(iterate, state_ssf):
        body
        return state_ssf,
      state_ast_tuple = ag__.for_loop(
          iterated, extra_cond_name, body_name, (state,))
    """
    # Keyword names below match the placeholders in the template.
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iterated=node.iter,
        iterate=node.target,
        extra_cond_name=self.context.namer.new_symbol('extra_cond',
                                                      all_referenced),
        extra_cond_expr=extra_cond,
        body_name=self.context.namer.new_symbol('loop_body', all_referenced),
        body=node_body)

    return node
Beispiel #36
0
  def test_entity_scope_tracking(self):
    """Checks that enclosing_entities lists every enclosing def/class/lambda."""

    class TestTransformer(transformer.Base):

      # The choice of node to assign to is arbitrary. Using Assign because it's
      # easy to find in the tree.
      def visit_Assign(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

      # This will show up in the lambda function.
      def visit_BinOp(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

    tr = TestTransformer(self._context_for_testing())

    # Parsed below; its statement layout is indexed directly, so do not add
    # statements or docstrings inside it.
    def test_function():
      a = 0

      class TestClass(object):

        def test_method(self):
          b = 0
          def inner_function(x):
            c = 0
            d = lambda y: (x + y)
            return c, d
          return b, inner_function
      return a, TestClass

    node, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    # Navigate to each nesting level of test_function.
    test_function_node = node.body[0]
    test_class = test_function_node.body[1]
    test_method = test_class.body[0]
    inner_function = test_method.body[1]
    lambda_node = inner_function.body[1].value

    # The annotated nodes: the a/b/c assignments and the lambda's BinOp body.
    a = test_function_node.body[0]
    b = test_method.body[0]
    c = inner_function.body[0]
    lambda_expr = lambda_node.body

    self.assertEqual(
        (test_function_node,), anno.getanno(a, 'enclosing_entities'))
    self.assertEqual((test_function_node, test_class, test_method),
                     anno.getanno(b, 'enclosing_entities'))
    self.assertEqual(
        (test_function_node, test_class, test_method, inner_function),
        anno.getanno(c, 'enclosing_entities'))
    self.assertEqual((test_function_node, test_class, test_method,
                      inner_function, lambda_node),
                     anno.getanno(lambda_expr, 'enclosing_entities'))
Beispiel #37
0
 def _validate_no_live_vars_created(self, node):
     """Rejects loops whose body creates symbols listed in LIVE_VARS_OUT.

     Raises:
       ValueError: if the node's LIVE_VARS_OUT annotation intersects the
         symbols created in its body scope.
     """
     scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
     escaping = anno.getanno(node, anno.Static.LIVE_VARS_OUT) & scope.created
     if not escaping:
         return
     raise ValueError(
         'The following variables are created inside the loop and used later:'
         '\n%s\n'
         'Variables must be declared outside loops because loops may not'
         ' necessarily execute.' %
         self._fmt_symbol_list(escaping))
Beispiel #38
0
    def test_constructor_detection(self):
        """Checks type and type_fqn annotations on a class-constructor call."""
        def test_fn():
            opt = training.GradientDescentOptimizer(0.1)
            return opt

        node = self._parse_and_analyze(test_fn, {'training': training})
        call_node = node.body[0].body[0].value
        # assertEqual, not the assertEquals alias (removed in Python 3.12).
        self.assertEqual(training.GradientDescentOptimizer,
                         anno.getanno(call_node, 'type'))
        self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                         anno.getanno(call_node, 'type_fqn'))
Beispiel #39
0
 def _validate_no_live_vars_created(self, node):
   """Raises ValueError if symbols created in the loop body are in LIVE_VARS_OUT."""
   created_in_body = anno.getanno(node, annos.NodeAnno.BODY_SCOPE).created
   leaked = anno.getanno(node, anno.Static.LIVE_VARS_OUT) & created_in_body
   if not leaked:
     return
   listing = self._fmt_symbol_list(leaked)
   raise ValueError(
       'The following variables are created inside the loop and used later:'
       '\n%s\n'
       'Variables must be declared outside loops because loops may not'
       ' necessarily execute.' % listing)
  def test_namespace(self):
    """Checks live_val and fqn on a call to a function from the namespace."""

    def foo():
      return 'bar'

    def test_fn():
      return foo()

    node = self._parse_and_analyze(test_fn, {'foo': foo})
    func_node = node.body[0].body[0].value.func
    # assertEqual, not the assertEquals alias (removed in Python 3.12).
    self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
    self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))
 def _try_resolve_target(self, node):
   """Works for methods of objects of known type.

   Returns:
     The node's 'live_val' annotation if present. Otherwise, for an Attribute
     node with a 'type' annotation, the attribute `node.attr` looked up on
     that type. Otherwise None.

   Raises:
     ValueError: if the annotated type has no attribute `node.attr`.
   """
   if anno.hasanno(node, 'live_val'):
     return anno.getanno(node, 'live_val')
   if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
     owner_type = anno.getanno(node, 'type')
     if hasattr(owner_type, node.attr):
       return getattr(owner_type, node.attr)
     raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                      (owner_type, node.attr))
   return None
Beispiel #42
0
 def _try_resolve_target(self, node):
   """Works for methods of objects of known type.

   Returns:
     The node's 'live_val' annotation if present. Otherwise, for an Attribute
     node with a 'type' annotation, the attribute `node.attr` looked up on
     that type. Otherwise None.

   Raises:
     ValueError: if the annotated type has no attribute `node.attr`.
   """
   if anno.hasanno(node, 'live_val'):
     return anno.getanno(node, 'live_val')
   if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
     owner_type = anno.getanno(node, 'type')
     if hasattr(owner_type, node.attr):
       return getattr(owner_type, node.attr)
     raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                      (owner_type, node.attr))
   return None
Beispiel #43
0
  def test_namespace(self):
    """Checks live_val and fqn on a call to a function from the namespace."""

    def foo():
      return 'bar'

    def test_fn():
      return foo()

    node = self._parse_and_analyze(test_fn, {'foo': foo})
    func_node = node.body[0].body[0].value.func
    # assertEqual, not the assertEquals alias (removed in Python 3.12).
    self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
    self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))
def source_map(nodes, code, filename, indices_in_code):
    """Creates a source map between an annotated AST and the code it compiles to.

    Args:
      nodes: Iterable[ast.AST]
      code: Text
      filename: Optional[Text]
      indices_in_code: Union[int, Iterable[int]], the positions at which
          nodes appear in code. The parser always returns a module when parsing
          code. This argument indicates the position in that module's body at
          which the corresponding node should appear. A single int is treated
          as a one-element sequence.

    Returns:
      Dict[LineLocation, OriginInfo], mapping line locations in code to the
      origin annotations of the corresponding nodes in `nodes`.
    """
    # The documented contract accepts a bare int; normalize it so the
    # comprehension below works for both forms.
    if isinstance(indices_in_code, int):
        indices_in_code = (indices_in_code,)

    reparsed_nodes = parser.parse_str(code)
    reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code]

    resolve(reparsed_nodes, code)
    result = {}

    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
        # Note: generated code might not be mapped back to its origin.
        # TODO(mdan): Generated code should always be mapped to something.
        origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
        final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
        if origin_info is None or final_info is None:
            continue

        line_loc = LineLocation(filename, final_info.loc.lineno)

        existing_origin = result.get(line_loc)
        if existing_origin is not None:
            # Overlaps may exist because of child nodes, but almost never to
            # different line locations. An exception is decorated functions,
            # where both lines are mapped to the same line in the AST.

            # Line overlaps: keep bottom node.
            if existing_origin.loc.line_loc == origin_info.loc.line_loc:
                if existing_origin.loc.lineno >= origin_info.loc.lineno:
                    continue

            # In case of overlaps, keep the leftmost node.
            if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
                continue

        result[line_loc] = origin_info

    return result
Beispiel #45
0
  def test_primitive_values(self):
    """A free name bound to True in the namespace gets the bool class as fqn."""

    # Local binding so test_fn compiles; the analysis uses the namespace value.
    a = None

    def test_fn():
      return a

    node = self._parse_and_analyze(test_fn, {'a': True})
    retval_node = node.body[0].body[0].value
    builtins_name = '__builtin__' if six.PY2 else 'builtins'
    self.assertEqual(anno.getanno(retval_node, 'fqn'), (builtins_name, 'bool'))
Beispiel #46
0
    def test_type_annotation(self):
        """Checks element_type on the values of test_fn's first two statements."""
        class Foo(object):
            pass

        def test_fn():
            f = []
            f = utils.set_element_type(f, Foo)
            return f

        namespace = {'Foo': Foo, 'utils': utils}
        fn_statements = self._parse_and_analyze(test_fn, namespace).body[0].body
        for stmt in fn_statements[:2]:
            self.assertEqual(anno.getanno(stmt.value, 'element_type'), Foo)
  def visit_Call(self, node):
    """Converts or wraps a call depending on what its target resolves to.

    Before visiting children, marks the first argument of calls to one of
    `self.ctx.program.autograph_decorators` with 'graph_ready'. Afterwards:
      * targets with a 'live_val' that are compilable are renamed via
        _rename_compilable_function;
      * targets whose 'fqn' is in KNOWN_NUMPY_FUNCTIONS are wrapped with
        _wrap_to_py_func_single_return;
      * targets without 'live_val' are left as-is for `dict` and `super(...)`
        calls, and otherwise go through _insert_dynamic_conversion when the
        program is recursive.

    Returns:
      The (possibly replaced) call node.

    Raises:
      ValueError: if a decorator call has no positional arguments.
      NotImplementedError: if the target has a 'live_val' but is neither
        compilable nor a known numpy function.
    """
    # If the function call is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if target_entity in self.ctx.program.autograph_decorators:
        if len(node.args) < 1:
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least one positional argument.' %
              target_entity)
        anno.setanno(node.args[0], 'graph_ready', True)

    # Children are visited before the call itself is converted.
    self.generic_visit(node)
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if anno.hasanno(node.func, 'fqn'):
        target_fqn = anno.getanno(node.func, 'fqn')
      else:
        target_fqn = None
      if self._function_is_compilable(target_entity):
        node = self._rename_compilable_function(node)
      elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
        # TODO(mdan): Should we replace these with equivalent TF ops instead?
        node = self._wrap_to_py_func_single_return(
            node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
      else:
        raise NotImplementedError(
            'py_func with return values (unknown function)')
    else:
      if anno.hasanno(node.func, anno.Basic.QN):
        # Special-case a few builtins that otherwise go undetected. This
        # normally doesn't pose a problem, but the dict built-in doesn't
        # work with inspect.getargspec which is required for dynamic functions.
        # Note: expecting this is resilient to aliasing (e.g.
        # dict = an_evil_dict), because in those cases the regular mechanisms
        # process a simple user function.
        qn = anno.getanno(node.func, anno.Basic.QN)
        # Add items to this list as needed.
        if str(qn) in ('dict',):
          return node

      if ast_util.matches(node, 'super(_)'):
        # super() calls are preserved. The class conversion mechanism will
        # ensure that they return the correct value.
        return node

      if self.ctx.program.recursive:
        node = self._insert_dynamic_conversion(node)
    return node
    def test_for(self):
        """Checks the body scope of a for loop and that scope's parent."""
        def test_fn(a):
            b = a
            for _ in a:
                c = b
                b -= 1
            return b, c

        node, _ = self._parse_and_analyze(test_fn)
        loop = node.body[0].body[1]
        loop_scope = anno.getanno(loop, NodeAnno.BODY_SCOPE)
        self.assertScopeIsRmc(loop_scope, ('b', ), ('b', 'c'), ('c', ))
        self.assertScopeIsRmc(loop_scope.parent, ('a', 'b', 'c'),
                              ('b', 'c', '_'), ('a', 'b', 'c', '_'))
Beispiel #49
0
    def test_class_members_in_with_stmt(self):
        """Checks constructor and method annotations inside a with statement."""
        def test_fn(x):
            with session.Session() as sess:
                sess.run(x)

        node = self._parse_and_analyze(test_fn, {'session': session})
        with_stmt = node.body[0].body[0]
        constructor_call = with_stmt.items[0].context_expr
        # assertEqual, not the assertEquals alias (removed in Python 3.12).
        self.assertEqual(session.Session,
                         anno.getanno(constructor_call, 'type'))
        self.assertEqual((session.__name__, 'Session'),
                         anno.getanno(constructor_call, 'type_fqn'))

        method_call = with_stmt.body[0].value.func
        self.assertEqual(session.Session.run,
                         anno.getanno(method_call, 'live_val'))
Beispiel #50
0
    def test_basic(self):
        """Round-trips an annotation through setanno, getanno and delanno."""
        node = ast.Name()
        key = 'foo'

        def check_absent():
            self.assertFalse(anno.hasanno(node, key))
            with self.assertRaises(AttributeError):
                anno.getanno(node, key)

        check_absent()

        anno.setanno(node, key, 3)
        self.assertTrue(anno.hasanno(node, key))
        self.assertEqual(3, anno.getanno(node, key))

        anno.delanno(node, key)
        check_absent()
Beispiel #51
0
    def visit_Name(self, node):
        """Attaches 'live_val' and 'fqn' annotations to loaded names.

        Only Load-context names are processed. Names that are neither local nor
        parameters are resolved from self.literals or self.context.namespace.
        Unmodified names found in self.context.arg_values are then annotated
        from there, overwriting any earlier 'live_val'/'fqn' set here.

        Returns:
          The same node.
        """
        self.generic_visit(node)
        if isinstance(node.ctx, gast.Load):
            # These annotations are expected to have been set by an earlier pass.
            assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
            symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
            assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
            symbol_is_modified = anno.getanno(node,
                                              NodeAnno.IS_MODIFIED_SINCE_ENTRY)
            assert anno.hasanno(node, NodeAnno.IS_PARAM), node
            symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM)

            if not symbol_is_local and not symbol_is_param:
                if node.id in self.literals:
                    anno.setanno(node, 'live_val', self.literals[node.id])
                elif node.id in self.context.namespace:
                    obj = self.context.namespace[node.id]
                    anno.setanno(node, 'live_val', obj)
                    if hasattr(obj, '__name__'):
                        anno.setanno(node, 'fqn', (obj.__name__, ))
                    elif hasattr(obj, '__class__'):
                        # Objects without __name__ (e.g. True) are identified
                        # by their class's module and name.
                        obj_class = obj.__class__
                        anno.setanno(
                            node, 'fqn',
                            (obj_class.__module__, obj_class.__name__))
                    else:
                        # NOTE(review): every Python object has __class__, so
                        # this branch appears unreachable; primitives take the
                        # branch above rather than landing here.
                        pass
                else:
                    pass
                    # TODO (mdan): Should we raise an error here? id:997
                    # https://github.com/imdone/tensorflow/issues/998
                    # Can encounter this when:
                    #  * a symbol truly lacks reference
                    #  * a symbol is new, like the new name of a function we just renamed.
            else:
                pass
                # TODO (mdan): Attempt to trace its value through the local chain. id:730
                # https://github.com/imdone/tensorflow/issues/731
                # TODO (mdan): Use type annotations as fallback. id:700
                # https://github.com/imdone/tensorflow/issues/701

            # Runs after the branch above and may overwrite its annotations.
            if not symbol_is_modified:
                if node.id in self.context.arg_values:
                    obj = self.context.arg_values[node.id]
                    anno.setanno(node, 'live_val', obj)
                    anno.setanno(node, 'fqn', (obj.__class__.__name__, ))
        return node
 def visit_While(self, node):
     """Visits the test, rebuilds the body via _process_loop_block, visits else."""
     self.generic_visit(node.test)
     loop_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
     node.body = self._process_loop_block(node.body, loop_scope)
     # Else statements only get their children visited.
     for else_stmt in node.orelse:
         self.generic_visit(else_stmt)
     return node
 def _visit_and_reindent(self, nodes):
   """Visits a statement list, honoring INDENT_BLOCK_REMAINDER requests.

   When a visited statement carries an INDENT_BLOCK_REMAINDER annotation
   (a (destination list, alias map) pair), all subsequent statements are
   appended to that destination list instead, with the accumulated alias map
   applied via ast_util.rename_symbols.

   Returns:
     The new top-level statement list.

   Raises:
     ValueError: if a reindent was requested but no statement followed it.
   """
   new_nodes = []
   current_dest = new_nodes
   alias_map = {}
   reindent_requested = False
   for n in nodes:
     n = self.visit(n)
     # NOTE: the order in which these statements execute is important; in
     # particular, watch out for ending up with cycles in the AST.
     if alias_map:
       n = ast_util.rename_symbols(n, alias_map)
     if isinstance(n, (list, tuple)):
       current_dest.extend(n)
     else:
       current_dest.append(n)
     # NOTE(review): when visit returned a list/tuple, `n` is that sequence
     # here; confirm anno.hasanno is meant to be called on it.
     if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
       reindent_requested = True
       new_dest, new_alias_map = anno.getanno(
           n, anno.Basic.INDENT_BLOCK_REMAINDER)
       anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
       # Earlier aliases take precedence over the newly requested ones.
       new_alias_map.update(alias_map)
       alias_map = new_alias_map
       current_dest = new_dest
   # current_dest is only empty here if nothing followed the last request.
   if reindent_requested and not current_dest:
     # TODO(mdan): There may still be something that could be done.
     raise ValueError('Unable to insert statement into the computation flow: '
                      'it is not followed by any computation which '
                      'the statement could gate.')
   return new_nodes
Beispiel #54
0
 def visit_Attribute(self, node):
     """Gives `x.attr` a QN built from x's QN, when x has one."""
     node = self.generic_visit(node)
     if not anno.hasanno(node.value, anno.Basic.QN):
         return node
     owner_qn = anno.getanno(node.value, anno.Basic.QN)
     anno.setanno(node, anno.Basic.QN, QN(owner_qn, attr=node.attr))
     return node
    def test_nested_if(self):
        """Checks both branch scopes of an if nested inside another if."""
        def test_fn(b):
            if b > 0:
                if b < 5:
                    a = b
                else:
                    a = b * b
            return a

        node, _ = self._parse_and_analyze(test_fn)
        outer_if = node.body[0].body[0]
        inner_if_node = outer_if.body[0]
        for scope_key in (NodeAnno.BODY_SCOPE, NodeAnno.ORELSE_SCOPE):
            self.assertScopeIsRmc(anno.getanno(inner_if_node, scope_key),
                                  ('b', ), ('a', ), ('a', ))
Beispiel #56
0
  def _track_symbol(self, node):
    """Records how node's QN is used in the current scope and annotates node.

    Store marks a write, Load marks a read, Param marks creation plus param.
    The node then receives IS_LOCAL, IS_MODIFIED_SINCE_ENTRY and IS_PARAM
    annotations, and the QN is marked returned inside a return statement.

    Raises:
      ValueError: for any other expression context.
    """
    # Attributes or subscripts on a call result (e.g. `a().b`) have no QN.
    if not anno.hasanno(node, anno.Basic.QN):
      return
    qn = anno.getanno(node, anno.Basic.QN)
    ctx = node.ctx

    if isinstance(ctx, gast.Store):
      self.scope.mark_write(qn)
    elif isinstance(ctx, gast.Load):
      self.scope.mark_read(qn)
    elif isinstance(ctx, gast.Param):
      # A Param context appears in a function definition, i.e. it defines
      # the variable.
      # TODO(mdan): This may be incorrect with nested functions, which need a
      # notion of hiding args from the parent scope rather than writing them.
      self.scope.mark_creation(qn)
      self.scope.mark_param(qn)
    else:
      raise ValueError('Unknown context %s for node %s.' % (type(ctx), qn))

    anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))
    anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY,
                 self.scope.is_modified_since_entry(qn))
    anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(qn))

    if self._in_return_statement:
      self.scope.mark_returned(qn)
Beispiel #57
0
    def visit_Subscript(self, node):
        """Rewrites a loaded index subscript into an `ag__.get_item` call.

        Subscripts in non-Load contexts are returned unchanged.

        Raises:
          NotImplementedError: if the subscript is not a plain index.
        """
        node = self.generic_visit(node)
        if not isinstance(node.slice, gast.Index):
            # TODO(mdan): It might make more sense to wave them through.
            raise NotImplementedError('non-index slice')

        if isinstance(node.ctx, gast.Load):
            none_expr = templates.replace_as_expression('None')
            dtype = anno.getanno(node.value, 'element_type', default=none_expr)
            template = """
      ag__.get_item(
          target,
          key,
          opts=ag__.GetItemOpts(element_dtype=dtype))
    """
            return templates.replace_as_expression(
                template, target=node.value, key=node.slice, dtype=dtype)

        # Index writes are handled at a higher level, where the rvalue is also
        # available.
        return node
    def visit_FunctionDef(self, node):
        """Runs an Analyzer over this function's subgraph, then its children.

        For a nested function, the parent analyzer's out-state at this node
        (minus symbols modified in the function body) is fed in as extra
        input on node.args. self.current_analyzer is set to this function's
        analyzer while its args and body are visited, then restored.

        Returns:
          The same node, with args and body replaced by their visited forms.
        """
        parent_analyzer = self.current_analyzer
        subgraph = self.graphs[node]

        # Preorder tree processing:
        #  1. if this is a child function, the parent was already analyzed and it
        #     has the proper state value for the subgraph's entry
        #  2. analyze the current function body
        #  3. recursively walk the subtree; child functions will be processed
        analyzer = Analyzer(subgraph, self.definition_factory)
        if parent_analyzer is not None:
            # Wire the state between the two subgraphs' analyzers.
            parent_out_state = parent_analyzer.out[
                parent_analyzer.graph.index[node]]
            # Exception: symbols modified in the child function are local to it
            body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
            # NOTE(review): if the out-state type implements in-place `-=`,
            # this also mutates parent_analyzer.out — confirm the type.
            parent_out_state -= body_scope.modified
            analyzer.extra_in[node.args] = parent_out_state

        # Complete the analysis for the local function and annotate its body.
        analyzer.visit_forward()

        # Recursively process any remaining subfunctions.
        self.current_analyzer = analyzer
        # Note: not visiting name, decorator_list and returns because they don't
        # apply to this analysis.
        # TODO(mdan): Should we still process the function name?
        node.args = self.visit(node.args)
        node.body = self.visit_block(node.body)
        self.current_analyzer = parent_analyzer

        return node
    def visit_node(self, node):
        """Recomputes in_/out sets for one CFG node.

        With a SCOPE annotation: out is the union of the successors' in_ sets
        and in = (used | extra_gen) | (out - modified). Without one, the node
        must be a Name, Continue or Break and its previous in_ set is passed
        through unchanged.

        Returns:
          True if the node's in_ set changed.
        """
        prev_live_in = self.in_[node]
        ast_node = node.ast_node

        if not anno.hasanno(ast_node, anno.Static.SCOPE):
            # Nodes without a scope annotation are assumed not to touch any
            # symbols; the Name case covers literal names such as False.
            assert isinstance(ast_node,
                              (gast.Name, gast.Continue, gast.Break)), type(
                                  ast_node)
            live_in = prev_live_in
            live_out = live_in
        else:
            node_scope = anno.getanno(ast_node, anno.Static.SCOPE)
            gen = node_scope.used | self.extra_gen.get(ast_node, frozenset())
            # TODO(mdan): verify whether composites' parents need to be added.
            # E.g. if x.y is live whether x needs to be added. Theoretically the
            # activity analysis should have both so that wouldn't be needed.
            kill = node_scope.modified
            live_out = set()
            for successor in node.next:
                live_out |= self.in_[successor]
            live_in = gen | (live_out - kill)

        self.in_[node] = live_in
        self.out[node] = live_out

        # TODO(mdan): Move this to the superclass?
        return prev_live_in != live_in
Beispiel #60
0
    def visit_While(self, node):
        """Lowers `break` statements in a while loop into a guard variable.

        The body is processed by _process_body with a fresh `break_` symbol.
        If a break was used, the loop is rebuilt from a template that
        initializes the symbol, adds `and not <symbol>` to the loop condition,
        and guards the else clause with _guard_if_present.

        Returns:
          The original node, or the templated replacement if a break was used.
        """
        scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
        break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

        node.test = self.visit(node.test)
        node.body, break_used = self._process_body(node.body, break_var)
        # A break in the else clause applies to the containing scope.
        node.orelse = self.visit_block(node.orelse)

        if break_used:
            # Python's else clause only triggers if the loop exited cleanly
            # (i.e. break did not trigger).
            guarded_orelse = self._guard_if_present(node.orelse, break_var)

            # Placeholder names must match the keyword arguments passed to
            # templates.replace; the loop condition is therefore `test`, so
            # the original node.test is substituted into the new loop.
            template = """
        var_name = tf.constant(False)
        while test and not var_name:
          body
        else:
          orelse
      """
            node = templates.replace(template,
                                     var_name=break_var,
                                     test=node.test,
                                     body=node.body,
                                     orelse=guarded_orelse)

        return node