Beispiel #1
0
  def test_if_subscripts(self):

    def test_fn(a, b, c, e):
      if a > 0:
        a[b] = -a[c]
        d = 2 * a
      else:
        a[0] = e
        d = 1
      return d

    node, _ = self._parse_and_analyze(test_fn)
    if_node = node.body[0].body[0]
    body_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)

    # The if branch: subscripts appear both as reads (a[c]) and writes (a[b]).
    self.assertScopeIsRmc(
        body_scope,
        ('a', 'b', 'c', 'a[c]'),
        ('a[b]', 'd'),
        ('d',),
    )
    # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"?
    self.assertScopeIsRmc(
        orelse_scope,
        ('a', 'e'),
        ('a[0]', 'd'),
        ('d',),
    )
    # The scope enclosing the if statement.
    self.assertScopeIsRmc(
        orelse_scope.parent,
        ('a', 'b', 'c', 'd', 'e', 'a[c]'),
        ('d', 'a[b]', 'a[0]'),
        ('a', 'b', 'c', 'd', 'e'),
    )
  def test_resolve(self):

    def test_fn(x):
      """Docstring."""
      return x  # comment

    node, source = parser.parse_entity(test_fn)
    fn_node = node.body[0]

    origin_info.resolve(fn_node, source)

    # Each row: (annotated node, lineno, col_offset, source line, comment).
    expected_origins = (
        (fn_node, 1, 0, 'def test_fn(x):', None),
        (fn_node.body[0], 2, 2, '  """Docstring."""', None),
        (fn_node.body[1], 3, 2, '  return x  # comment', 'comment'),
    )
    for target, lineno, col_offset, line, comment in expected_origins:
      origin = anno.getanno(target, anno.Basic.ORIGIN)
      self.assertEqual(origin.loc.lineno, lineno)
      self.assertEqual(origin.loc.col_offset, col_offset)
      self.assertEqual(origin.source_code_line, line)
      if comment is None:
        self.assertIsNone(origin.comment)
      else:
        self.assertEqual(origin.comment, comment)
Beispiel #3
0
  def test_if(self):

    def test_fn(x):
      if x > 0:
        x = -x
        y = 2 * x
        z = -y
      else:
        x = 2 * x
        y = -x
        u = -y
      return z, u

    node, _ = self._parse_and_analyze(test_fn)
    if_node = node.body[0].body[0]
    body_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)
    all_symbols = ('x', 'y', 'z', 'u')

    self.assertScopeIsRmc(body_scope, ('x', 'y'), ('x', 'y', 'z'), ('y', 'z'))
    # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
    self.assertScopeIsRmc(body_scope.parent, ('x', 'z', 'u'), all_symbols,
                          all_symbols)
    self.assertScopeIsRmc(orelse_scope, ('x', 'y'), ('x', 'y', 'u'),
                          ('y', 'u'))
    self.assertScopeIsRmc(orelse_scope.parent, ('x', 'z', 'u'), all_symbols,
                          all_symbols)
Beispiel #4
0
  def test_if(self):

    def test_fn(x):
      if x > 0:
        x = -x
        y = 2 * x
        z = -y
      else:
        x = 2 * x
        y = -x
        u = -y
      return z, u

    node, _ = self._parse_and_analyze(test_fn)
    if_node = node.body[0].body[0]
    body_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)
    # The scope enclosing the if statement sees every symbol of the function.
    all_symbols = ('x', 'y', 'z', 'u')

    self.assertScopeIs(body_scope, ('x', 'y'), ('x', 'y', 'z'))
    self.assertScopeIs(body_scope.parent, all_symbols, all_symbols)
    self.assertScopeIs(orelse_scope, ('x', 'y'), ('x', 'y', 'u'))
    self.assertScopeIs(orelse_scope.parent, all_symbols, all_symbols)
  def test_resolve(self):

    source = textwrap.dedent("""
        def test_fn(x):
          '''Docstring.'''
          return x  # comment
    """)

    node = parser.parse_str(source)

    origin_info.resolve(node, source)

    # Line numbers count the leading newline of `source`, so the def is line 2.
    # Each row: (annotated node, lineno, col_offset, source line, comment).
    expected_origins = (
        (node, 2, 0, 'def test_fn(x):', None),
        (node.body[0], 3, 2, "  '''Docstring.'''", None),
        (node.body[1], 4, 2, '  return x  # comment', 'comment'),
    )
    for target, lineno, col_offset, line, comment in expected_origins:
      origin = anno.getanno(target, anno.Basic.ORIGIN)
      self.assertEqual(origin.loc.lineno, lineno)
      self.assertEqual(origin.loc.col_offset, col_offset)
      self.assertEqual(origin.source_code_line, line)
      if comment is None:
        self.assertIsNone(origin.comment)
      else:
        self.assertEqual(origin.comment, comment)
 def visit_For(self, node):
   """Visits the loop target, then processes body and else in their scopes.

   Each block is handed to `_process_block` along with the scope annotation
   recorded for it on `node`.
   """
   node.target = self.visit(node.target)
   body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
   node.body = self._process_block(body_scope, node.body)
   orelse_scope = anno.getanno(node, NodeAnno.ORELSE_SCOPE)
   node.orelse = self._process_block(orelse_scope, node.orelse)
   return node
  def _get_loop_state(self, node):
    """Determines which symbols must be threaded through the loop as state.

    Args:
      node: the loop AST node, carrying body-scope and liveness annotations.

    Returns:
      A tuple (loop_state, reserved_symbols): the symbols modified in the loop
      body that are live on entry, and all symbols the body references.

    Raises:
      NameError: if a simple (non-composite) loop state symbol is not defined
        when entering the loop.
      NotImplementedError: if a symbol modified in the loop but not live on
        entry is live after the loop.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    reserved_symbols = body_scope.referenced
    modified_in_body = body_scope.modified

    # Note that it doesn't matter whether the variables are live after the loop.
    # If the loop modifies them nonlocally (e.g. the result of an iteration
    # depends on the previous iteration), then they need to be included in
    # the loop state, regardless of whether they are later used or not.
    loop_state = modified_in_body & live_in

    # Only simple variables must be defined. The composite ones will be
    # implicitly checked at runtime.
    undefined_simple_lives = {
        sym for sym in loop_state - defined_in if sym.is_simple()
    }
    if undefined_simple_lives:
      raise NameError(
          'cannot convert loop: it includes symbols that are undefined'
          ' when entering the loop: {}'.format(
              self._fmt_symbols(undefined_simple_lives)))

    live_defs_in_loop = (modified_in_body - live_in) & live_out
    if live_defs_in_loop:
      # TODO(mdan): Include reference to explanation why.
      raise NotImplementedError(
          'cannot convert loop: it includes symbols that are defined'
          ' inside the loop, but used later: {}. To fix, initialize'
          ' these symbols before the loop'.format(
              self._fmt_symbols(live_defs_in_loop)))

    return loop_state, reserved_symbols
Beispiel #8
0
  def test_if_attributes(self):

    def test_fn(a):
      if a > 0:
        a.b = -a.c
        d = 2 * a
      else:
        a.b = a.c
        d = 1
      return d

    node, _ = self._parse_and_analyze(test_fn)
    if_node = node.body[0].body[0]
    body_scope = anno.getanno(if_node, NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(if_node, NodeAnno.ORELSE_SCOPE)

    # Both branches are expected to produce identical scopes.
    for branch_scope in (body_scope, orelse_scope):
      self.assertScopeIsRmc(
          branch_scope,
          ('a', 'a.c'),
          ('a.b', 'd'),
          ('d',),
      )
    # The scope enclosing the if statement.
    self.assertScopeIsRmc(
        body_scope.parent,
        ('a', 'a.c', 'd'),
        ('a.b', 'd'),
        ('a', 'd'),
    )
  def test_origin_info_preserved_in_moved_nodes(self):

    class TestTransformer(transformer.Base):

      def visit_If(self, node):
        return node.body

    tr = TestTransformer(self._simple_context())

    def test_fn():
      x = 1
      if x > 0:
        x = 1
        x += 3
      return x

    node, source = parser.parse_entity(test_fn, future_features=())
    origin_info.resolve(node, source)
    node = tr.visit(node)

    def origin_lineno(n):
      return anno.getanno(n, anno.Basic.ORIGIN).loc.lineno

    # The If was replaced by its body; the hoisted statements must keep the
    # line numbers they had inside the If.
    assign_node, aug_assign_node = node.body[1], node.body[2]
    self.assertEqual(origin_lineno(assign_node), 4)
    self.assertEqual(origin_lineno(aug_assign_node), 5)
Beispiel #10
0
  def _rename_compilable_function(self, node):
    """Points a call at the name the namer picks for its compiled target.

    The call target must carry `live_val` and `fqn` annotations. Calls
    annotated `is_constructor` always get a compiled class name; other calls
    get a compiled function name, and the namer decides whether to rename at
    all. When the target is a method, its receiver is prepended to the
    positional arguments.

    Args:
      node: the Call AST node.

    Returns:
      `node`, modified in place when a rename happens.
    """
    assert anno.hasanno(node.func, 'live_val')
    assert anno.hasanno(node.func, 'fqn')
    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = anno.getanno(node.func, 'fqn')
    namer = self.ctx.namer

    if anno.hasanno(node, 'is_constructor'):
      new_name = namer.compiled_class_name(
          target_fqn, live_entity=target_entity)
      do_rename = True
    else:
      if anno.hasanno(node.func, 'parent_type'):
        owner_type = anno.getanno(node.func, 'parent_type')
      else:
        # Fallback - not reliable.
        owner_type = inspect_utils.getmethodclass(target_entity)
      new_name, do_rename = namer.compiled_function_name(
          target_fqn, live_entity=target_entity, owner_type=owner_type)

    if not do_rename:
      return node

    if target_entity is not None and tf_inspect.ismethod(target_entity):
      # The renaming process will transform it into a regular function.
      # TODO(mdan): Is this complete? How does it work with nested members?
      node.args = [node.func.value] + node.args
    node.func = templates.replace_as_expression(
        'func_name', func_name=new_name)
    return node
Beispiel #11
0
  def visit_For(self, node):
    """Rewrites a `for` loop into a functional `ag__.for_stmt` call.

    The loop body is moved into a generated local function that takes the
    loop iterate and the loop state and returns the updated state. The
    optional `extra_test` annotation becomes a separate generated predicate
    over the state.

    Args:
      node: the `For` AST node; must carry the `NodeAnno.BODY_SCOPE`
        annotation.

    Returns:
      The statements produced by `templates.replace` for the template below.

    Raises:
      Whatever `_validate_no_live_vars_created` raises for `node`.
    """
    # Transform the loop's children before the loop itself.
    self.generic_visit(node)

    self._validate_no_live_vars_created(node)

    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    # Loop state: symbols modified in the body that were not created there.
    body_closure = body_scope.modified - body_scope.created
    all_referenced = body_scope.referenced

    state = list(body_closure)

    # Fresh names for the state symbols inside the generated functions,
    # derived from each symbol's ssf() and chosen not to clash with anything
    # the body references.
    state_ssf = [
        self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
    ]
    # Rename map for the body; only symbols whose fresh name differs.
    ssf_map = {
        name: ssf
        for name, ssf in zip(state, state_ssf)
        if str(name) != ssf
    }

    # With a single state symbol, the assignment target is the symbol itself
    # rather than a tuple.
    # NOTE(review): with an empty state this builds an empty Tuple target;
    # confirm the template below handles zero state symbols.
    if len(state) == 1:
      state = state[0]
      state_ssf = state_ssf[0]
      state_ast_tuple = state
    else:
      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)

    node_body = ast_util.rename_symbols(node.body, ssf_map)
    if anno.hasanno(node, 'extra_test'):
      extra_test = anno.getanno(node, 'extra_test')
      extra_test = ast_util.rename_symbols(extra_test, ssf_map)
    else:
      # Without an `extra_test` annotation the predicate is constant True.
      extra_test = parser.parse_expression('True')

    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    node = templates.replace(
        template,
        state=state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=node.iter,
        iterate=node.target,
        extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
        extra_test_expr=extra_test,
        body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
        body=node_body)

    return node
  def test_attribute_names(self):

    def test_fn():
      return constant_op.constant(0)

    node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
    func_node = node.body[0].body[0].value.func
    # An attribute call target resolves to the live function and to its
    # fully qualified name (module name, attribute name).
    self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
    self.assertEqual((constant_op.__name__, 'constant'),
                     anno.getanno(func_node, 'fqn'))
Beispiel #13
0
  def test_rename_symbols_annotations(self):
    # Renaming symbols must keep the very same annotation objects on the node.
    node = qual_names.resolve(parser.parse_str('a[i]'))
    anno.setanno(node, 'foo', 'bar')
    orig_anno = anno.getanno(node, 'foo')

    rename_map = {qual_names.QN('a'): qual_names.QN('b')}
    node = ast_util.rename_symbols(node, rename_map)

    self.assertIs(anno.getanno(node, 'foo'), orig_anno)
  def test_entity_scope_tracking(self):

    class TestTransformer(transformer.Base):

      # The choice of node to assign to is arbitrary. Using Assign because it's
      # easy to find in the tree.
      def visit_Assign(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

      # This will show up in the lambda function.
      def visit_BinOp(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

    tr = TestTransformer(self._simple_context())

    def test_function():
      a = 0

      class TestClass(object):

        def test_method(self):
          b = 0
          def inner_function(x):
            c = 0
            d = lambda y: (x + y)
            return c, d
          return b, inner_function
      return a, TestClass

    node, _ = parser.parse_entity(test_function, future_features=())
    node = tr.visit(node)

    test_function_node = node
    test_class = test_function_node.body[1]
    test_method = test_class.body[0]
    inner_function = test_method.body[1]
    lambda_node = inner_function.body[1].value

    # (annotated node, expected chain of enclosing entities, outermost first)
    expected_chains = (
        (test_function_node.body[0], (test_function_node,)),
        (test_method.body[0], (test_function_node, test_class, test_method)),
        (inner_function.body[0],
         (test_function_node, test_class, test_method, inner_function)),
        (lambda_node.body, (test_function_node, test_class, test_method,
                            inner_function, lambda_node)),
    )
    for annotated, chain in expected_chains:
      self.assertEqual(chain, anno.getanno(annotated, 'enclosing_entities'))
Beispiel #15
0
 def _validate_no_live_vars_created(self, node):
   """Rejects loops whose body creates variables that are live afterwards.

   Args:
     node: the loop AST node, carrying body-scope and live-out annotations.

   Raises:
     ValueError: if a variable created in the loop body is live after it.
   """
   body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
   live_vars_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
   offending_vars = live_vars_out & body_scope.created
   if not offending_vars:
     return
   raise ValueError(
       'The following variables are created inside the loop and used later:'
       '\n%s\n'
       'Variables must be declared outside loops because loops may not'
       ' necessarily execute.' % self._fmt_symbol_list(offending_vars))
Beispiel #16
0
  def test_constructor_detection(self):

    def test_fn():
      opt = training.GradientDescentOptimizer(0.1)
      return opt

    node = self._parse_and_analyze(test_fn, {'training': training})
    call_node = node.body[0].body[0].value
    # A constructor call is annotated with the constructed type and its FQN.
    self.assertEqual(training.GradientDescentOptimizer,
                     anno.getanno(call_node, 'type'))
    self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                     anno.getanno(call_node, 'type_fqn'))
Beispiel #17
0
 def _try_resolve_target(self, node):
   """Works for methods of objects of known type.

   Args:
     node: the AST node of a call target.

   Returns:
     The `live_val` annotation of `node` if present; otherwise, for an
     Attribute node annotated with a 'type', that type's attribute named
     `node.attr`; otherwise None.

   Raises:
     ValueError: if the annotated type has no attribute named `node.attr`.
   """
   if anno.hasanno(node, 'live_val'):
     return anno.getanno(node, 'live_val')
   if not (isinstance(node, gast.Attribute) and anno.hasanno(node, 'type')):
     return None
   owner_type = anno.getanno(node, 'type')
   if not hasattr(owner_type, node.attr):
     raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                      (owner_type, node.attr))
   return getattr(owner_type, node.attr)
  def test_namespace(self):

    def foo():
      return 'bar'

    def test_fn():
      return foo()

    node = self._parse_and_analyze(test_fn, {'foo': foo})
    func_node = node.body[0].body[0].value.func
    # A bare-name call target resolves to the namespace entry and its FQN.
    self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
    self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))
Beispiel #19
0
  def visit_Call(self, node):
    """Rewrites a call according to what its target resolves to.

    Targets without a `live_val` annotation are preserved for the `super()`
    special cases and otherwise passed to `_insert_dynamic_conversion`.
    Resolved targets are renamed (compilable functions the converter should
    compile), wrapped via `_wrap_to_py_func_single_return` (known NumPy
    functions), returned untouched (builtins), or only have their children
    visited (namedtuples, and compilable functions that should not compile).

    Raises:
      NotImplementedError: if a resolved target matches none of those cases.
    """
    func = node.func
    if not anno.hasanno(func, 'live_val'):
      # Special cases
      # TODO(mdan): These need a systematic review - there may be more.

      # 1. super() calls - these are preserved. The class conversion mechanism
      # will ensure that they return the correct value.
      if ast_util.matches(node, 'super(_)'):
        return node

      # 2. super().method calls - these are preserved as well, when the
      # conversion processes the entire class.
      if (ast_util.matches(node, 'super(_)._(_)') and
          self.ctx.info.owner_type is not None):
        return node

      return self._insert_dynamic_conversion(node)

    target_entity = anno.getanno(func, 'live_val')
    target_fqn = (
        anno.getanno(func, 'fqn') if anno.hasanno(func, 'fqn') else None)

    if self._function_is_compilable(target_entity):
      if self._should_compile(node, target_fqn):
        return self._rename_compilable_function(node)
      return self.generic_visit(node)

    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)

    if inspect_utils.isbuiltin(target_entity):
      # Note: Any builtin that passed the builtins converter is assumed to be
      # safe for graph mode.
      return node

    if inspect_utils.isnamedtuple(target_entity):
      # Although not compilable, we assume they are safe for graph mode.
      return self.generic_visit(node)

    # TODO(mdan): Insert dynamic conversion here instead.
    raise NotImplementedError(
        'py_func with return values (unknown function)')
Beispiel #20
0
  def test_params(self):

    def test_fn(a, b):  # pylint: disable=unused-argument
      return b

    node, _ = self._parse_and_analyze(test_fn)
    function_node = node.body[0]

    # The function body's scope, then the scope enclosing it.
    body_scope = anno.getanno(function_node, NodeAnno.BODY_SCOPE)
    self.assertScopeIs(body_scope, ('b',), ())
    self.assertScopeIs(body_scope.parent, ('b',), ('a', 'b'))

    # Both arguments, including the unused `a`, are recorded as params.
    arguments_scope = anno.getanno(function_node.args, anno.Static.SCOPE)
    self.assertSymbolSetsAre(
        ('a', 'b'), arguments_scope.params.keys(), 'params')
  def test_primitive_values(self):

    a = None

    def test_fn():
      return a

    node = self._parse_and_analyze(test_fn, {'a': True})
    retval_node = node.body[0].body[0].value
    # The builtins module is named differently in Python 2 and Python 3.
    builtins_module = '__builtin__' if six.PY2 else 'builtins'
    self.assertEqual(
        anno.getanno(retval_node, 'fqn'), (builtins_module, 'bool'))
Beispiel #22
0
  def visit_Call(self, node):
    """Converts a call, first marking arguments of marker decorators.

    If the target resolves (`live_val`) to one of the configured
    `strip_decorators`, its first positional argument is annotated
    `graph_ready` before the children are visited. After that, resolved
    targets are renamed (compilable functions) or wrapped (known NumPy
    functions). Unresolved `dict` and `super()` calls are preserved; other
    unresolved calls are dynamically converted when the `recursive` option is
    set.

    Raises:
      ValueError: if a strip decorator is called without positional arguments.
      NotImplementedError: if a resolved target is neither compilable nor a
        known NumPy function.
    """
    # If the function call is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      decorator = anno.getanno(node.func, 'live_val')
      if decorator in self.ctx.program.options.strip_decorators:
        if not node.args:
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least one positional argument.' %
              decorator)
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)

    # node.func is re-read here on purpose: generic_visit may have replaced it.
    if not anno.hasanno(node.func, 'live_val'):
      if anno.hasanno(node.func, anno.Basic.QN):
        # Special-case a few builtins that otherwise go undetected. This
        # normally doesn't pose a problem, but the dict built-in doesn't
        # work with inspect.getargspec which is required for dynamic functions.
        # Note: expecting this is resilient to aliasing (e.g.
        # dict = an_evil_dict), because in those cases the regular mechanisms
        # process a simple user function.
        qn = anno.getanno(node.func, anno.Basic.QN)
        # Add items to this list as needed.
        if str(qn) in ('dict',):
          return node

      if ast_util.matches(node, 'super(_)'):
        # super() calls are preserved. The class conversion mechanism will
        # ensure that they return the correct value.
        return node

      if self.ctx.program.options.recursive:
        return self._insert_dynamic_conversion(node)
      return node

    target_entity = anno.getanno(node.func, 'live_val')
    if anno.hasanno(node.func, 'fqn'):
      target_fqn = anno.getanno(node.func, 'fqn')
    else:
      target_fqn = None

    if self._function_is_compilable(target_entity):
      return self._rename_compilable_function(node)
    if target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
      # TODO(mdan): Should we replace these with equivalent TF ops instead?
      return self._wrap_to_py_func_single_return(
          node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
    raise NotImplementedError(
        'py_func with return values (unknown function)')
def create_source_map(nodes, code, filename, indices_in_code):
  """Creates a source map between an annotated AST and the code it compiles to.

  Args:
    nodes: Iterable[ast.AST, ...]
    code: Text
    filename: Optional[Text]
    indices_in_code: Union[int, Iterable[int, ...]], the positions at which
        nodes appear in code. The parser always returns a module when parsing
        code. This argument indicates the position in that module's body at
        which the corresponding node should appear. A single int is treated
        as a one-element sequence.

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.
  """
  # The docstring allows a single index; normalize it so it can be iterated
  # below (iterating a bare int would raise TypeError).
  if isinstance(indices_in_code, int):
    indices_in_code = (indices_in_code,)

  module_node = parser.parse_str(code)
  reparsed_nodes = [module_node.body[i] for i in indices_in_code]

  resolve(reparsed_nodes, code)
  result = {}

  for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
    # Note: generated code might not be mapped back to its origin.
    # TODO(mdan): Generated code should always be mapped to something.
    origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
    final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
    if origin_info is None or final_info is None:
      continue

    line_loc = LineLocation(filename, final_info.loc.lineno)

    existing_origin = result.get(line_loc)
    if existing_origin is not None:
      # Overlaps may exist because of child nodes, but almost never to
      # different line locations. Exceptions are decorated functions, where
      # both lines are mapped to the same line in the AST.

      # Line overlaps: keep bottom node.
      # NOTE(review): if `loc.line_loc` encodes (filename, lineno), equal
      # line_locs imply equal linenos and the `>=` test below always holds;
      # confirm the intended comparison.
      if existing_origin.loc.line_loc == origin_info.loc.line_loc:
        if existing_origin.loc.lineno >= origin_info.loc.lineno:
          continue

      # In case of overlaps, keep the leftmost node.
      if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
        continue

    result[line_loc] = origin_info

  return result
Beispiel #24
0
  def test_class_members_in_with_stmt(self):

    def test_fn(x):
      with session.Session() as sess:
        sess.run(x)

    node = self._parse_and_analyze(test_fn, {'session': session})
    # The context manager expression is annotated with the type it constructs.
    constructor_call = node.body[0].body[0].items[0].context_expr
    self.assertEqual(session.Session, anno.getanno(constructor_call, 'type'))
    self.assertEqual((session.__name__, 'Session'),
                     anno.getanno(constructor_call, 'type_fqn'))

    # Methods called on the `as` target resolve through the constructed type.
    method_call = node.body[0].body[0].body[0].value.func
    self.assertEqual(session.Session.run,
                     anno.getanno(method_call, 'live_val'))
Beispiel #25
0
  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive argument for a symbol.

    See lang/directives.py for details on directives.

    Example:
       # Given a directive in the code:
       ag.foo_directive(bar, baz=1)

       # One can write for an AST node Name(id='bar'):
       get_definition_directive(node, ag.foo_directive, 'baz', default=None)

    Args:
      node: ast.AST, the node representing the symbol for which the directive
        argument is needed.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, returned when no definition reaching `node` carries the
        directive argument.

    Returns:
      The argument value found on the definitions reaching `node` (passed to
      `compiler.ast_to_source` on conflict, so presumably an AST node), or
      `default` if none is found.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not defs:
      return default

    arg_values_found = [
        def_.directives[directive][arg]
        for def_ in defs
        if directive in def_.directives and arg in def_.directives[directive]
    ]

    if not arg_values_found:
      return default

    if len(arg_values_found) == 1:
      return arg_values_found[0]

    # If multiple annotations reach the symbol, they must all match. If they do,
    # return any of them.
    first_value = arg_values_found[0]
    for other_value in arg_values_found[1:]:
      if not ast_util.matches(first_value, other_value):
        qn = anno.getanno(node, anno.Basic.QN)
        raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
                         (qn, directive.__name__, arg,
                          compiler.ast_to_source(other_value).strip(),
                          compiler.ast_to_source(first_value).strip()))
    return first_value
 def _visit_and_reindent(self, nodes):
   """Visits a list of statements, re-nesting later ones when requested.

   A visited statement may carry an `anno.Basic.INDENT_BLOCK_REMAINDER`
   annotation holding a `(new_dest, new_alias_map)` pair. Every statement
   after it is then appended to `new_dest` instead of the current
   destination, and is renamed with the accumulated alias map.

   Args:
     nodes: the statements to visit.

   Returns:
     The list of statements emitted before the first reindent, plus the
     statement that requested it. Statements redirected to a `new_dest` are
     not in this list (presumably `new_dest` belongs to a node that was
     already emitted - confirm against the producer of the annotation).

   Raises:
     ValueError: if a reindent was requested but the last destination is
       still empty after all statements were processed.
   """
   new_nodes = []
   current_dest = new_nodes
   alias_map = {}
   reindent_requested = False
   for n in nodes:
     n = self.visit(n)
     # NOTE: the order in which these statements execute is important; in
     # particular, watch out for ending up with cycles in the AST.
     # Apply renames accumulated from earlier reindent annotations.
     if alias_map:
       n = ast_util.rename_symbols(n, alias_map)
     if isinstance(n, (list, tuple)):
       current_dest.extend(n)
     else:
       current_dest.append(n)
     if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
       reindent_requested = True
       new_dest, new_alias_map = anno.getanno(
           n, anno.Basic.INDENT_BLOCK_REMAINDER)
       anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
       # Existing aliases win over the new map on conflicting keys. Note this
       # mutates the dict that was stored in the annotation.
       new_alias_map.update(alias_map)
       alias_map = new_alias_map
       # Later statements go into the new destination.
       current_dest = new_dest
   # Only the last destination is checked: a reindent must be followed by at
   # least one statement to move into it.
   if reindent_requested and not current_dest:
     # TODO(mdan): There may still be something that could be done.
     raise ValueError('Unable to insert statement into the computation flow: '
                      'it is not followed by any computation which '
                      'the statement could gate.')
   return new_nodes
Beispiel #27
0
  def test_nested_if(self):

    def test_fn(b):
      if b > 0:
        if b < 5:
          a = b
        else:
          a = b * b
      return a

    node, _ = self._parse_and_analyze(test_fn)
    inner_if_node = node.body[0].body[0].body[0]
    # Both branches of the inner if are expected to have the same scope.
    for scope_key in (NodeAnno.BODY_SCOPE, NodeAnno.ORELSE_SCOPE):
      self.assertScopeIs(
          anno.getanno(inner_if_node, scope_key), ('b',), ('a',))
  def visit_While(self, node):
    """Lowers `break` statements in a while loop to a boolean flag.

    The body is processed with a freshly named break variable. If the body
    used it, the loop is rebuilt so that the flag is initialized to False,
    is checked in the loop condition, and guards the else clause.

    Returns:
      `node` if no break was used, else the statements produced by
      `templates.replace`.
    """
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

    node.test = self.visit(node.test)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if not break_used:
      return node

    # Python's else clause only triggers if the loop exited cleanly (e.g.
    # break did not trigger).
    guarded_orelse = self._guard_if_present(node.orelse, break_var)
    template = """
        var_name = tf.constant(False)
        while test and not var_name:
          body
        else:
          orelse
      """
    return templates.replace(
        template,
        var_name=break_var,
        test=node.test,
        body=node.body,
        orelse=guarded_orelse)
Beispiel #29
0
  def test_for(self):

    def test_fn(a):
      b = a
      for _ in a:
        c = b
        b -= 1
      return b, c

    node, _ = self._parse_and_analyze(test_fn)
    for_node = node.body[0].body[1]
    body_scope = anno.getanno(for_node, NodeAnno.BODY_SCOPE)

    # The loop body's scope, then the scope enclosing the loop (which also
    # sees the loop target `_`).
    self.assertScopeIs(body_scope, ('b',), ('b', 'c'))
    self.assertScopeIs(body_scope.parent, ('a', 'b', 'c'), ('b', 'c', '_'))
Beispiel #30
0
  def visit_node(self, node):
    """Recomputes the live-in and live-out sets of one graph node.

    For nodes with a scope annotation: live_out is the union of the live-in
    sets of `node.next`, and live_in = gen | (live_out - kill), where gen is
    the symbols used (plus any `extra_gen` entry) and kill the symbols
    modified.

    Returns:
      True if the node's live-in set changed.
    """
    prev_live_in = self.in_[node]

    if not anno.hasanno(node.ast_node, anno.Static.SCOPE):
      # Nodes that don't have a scope annotation are assumed not to touch any
      # symbols.
      # This Name node below is a literal name, e.g. False
      assert isinstance(node.ast_node,
                        (gast.Name, gast.Continue, gast.Break)), type(
                            node.ast_node)
      live_in = prev_live_in
      live_out = live_in
    else:
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
      gen = node_scope.used | self.extra_gen.get(node.ast_node, frozenset())
      # TODO(mdan): verify whether composites' parents need to be added.
      # E.g. if x.y is live whether x needs to be added. Theoretically the
      # activity analysis should have both so that wouldn't be needed.
      kill = node_scope.modified

      live_out = set()
      for successor in node.next:
        live_out |= self.in_[successor]
      live_in = gen | (live_out - kill)

    self.in_[node] = live_in
    self.out[node] = live_out

    # TODO(mdan): Move this to the superclass?
    return prev_live_in != live_in
Beispiel #31
0
    def visit(self, node):
        """Visits a node, tracking enclosing entities and origin information.

        Args:
          node: gast.AST, the node to visit.

        Returns:
          The result of the underlying visitor: a node, a list/tuple of nodes,
          or None. Nodes annotated with `anno.Basic.SKIP_PROCESSING` are
          returned unchanged. Replacement nodes without origin info inherit
          the origin of `node` (or of the enclosing node).

        Raises:
          ValueError: if `node` is not a `gast.AST`.
          AssertionError: if enter_local_scope/exit_local_scope calls made
            while visiting `node` are not well paired.
        """
        if not isinstance(node, gast.AST):
            # This is not that uncommon a mistake: various node bodies are lists, for
            # example, posing a land mine for transformers that need to recursively
            # call `visit`.  The error needs to be raised before the exception handler
            # below is installed, because said handler will mess up if `node` is not,
            # in fact, a node.
            msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
                   ' visit lists of nodes, use "visit_block" instead').format(
                       type(node))
            raise ValueError(msg)

        did_enter_function = isinstance(
            node, (gast.FunctionDef, gast.ClassDef, gast.Lambda))
        processing_expr_node = isinstance(node, gast.Expr)
        local_scope_size_at_entry = len(self._local_scope_state)

        parent_origin = self.ctx.current_origin
        if did_enter_function:
            self._enclosing_entities.append(node)
        if anno.hasanno(node, anno.Basic.ORIGIN):
            self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

        try:
            if processing_expr_node:
                entry_expr_value = node.value

            # Nodes marked SKIP_PROCESSING pass through unchanged. Without this
            # default, `result` would be unbound for them below.
            result = node
            if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
                result = super(Base, self).visit(node)
        finally:
            # Restore the traversal state even if the visitor raised, so a
            # caught exception does not leave a stale origin or entity behind.
            self.ctx.current_origin = parent_origin
            if did_enter_function:
                self._enclosing_entities.pop()

        # Adjust for consistency: replacing the value of an Expr with
        # an Assign node removes the need for the Expr node.
        if (processing_expr_node and isinstance(result, gast.Expr) and
                result.value != entry_expr_value):
            # When the replacement is a list, it is assumed that the list came
            # from a template that contained a number of statements, which
            # themselves are standalone and don't require an enclosing Expr.
            if isinstance(result.value,
                          (list, tuple, gast.Assign, gast.AugAssign)):
                result = result.value

        # By default, all replacements receive the origin info of the replaced node.
        if result is not node and result is not None:
            if isinstance(result, (list, tuple)):
                nodes_to_adjust = result
            else:
                nodes_to_adjust = (result, )
            inherited_origin = anno.getanno(node,
                                            anno.Basic.ORIGIN,
                                            default=parent_origin)
            if inherited_origin is not None:
                for n in nodes_to_adjust:
                    if not anno.hasanno(n, anno.Basic.ORIGIN):
                        anno.setanno(n, anno.Basic.ORIGIN, inherited_origin)

        # On exception the local scope integrity is not guaranteed, so this is
        # only checked on the success path.
        if local_scope_size_at_entry != len(self._local_scope_state):
            raise AssertionError(
                'Inconsistent local scope stack. Before entering node %s, the'
                ' stack had length %d, after exit it has length %d. This'
                ' indicates enter_local_scope and exit_local_scope are not'
                ' well paired.' % (node, local_scope_size_at_entry,
                                   len(self._local_scope_state)))
        return result
Beispiel #32
0
 def assertClosureTypes(self, node, expected):
     """Checks the CLOSURE_TYPES annotation of `node` against `expected`.

     Symbols are matched by their string form. Only the names listed in
     `expected` are checked; extra annotated symbols are ignored.
     """
     closure_types = anno.getanno(node, anno.Static.CLOSURE_TYPES)
     types_by_name = {
         str(symbol): types for symbol, types in closure_types.items()
     }
     for name, expected_types in expected.items():
         self.assertIn(name, types_by_name)
         self.assertEqual(types_by_name[name], expected_types)
 def visit_FunctionDef(self, node):
     """Visits arguments and decorators, then the body in the function scope."""
     node.args = self.generic_visit(node.args)
     node.decorator_list = self.visit_block(node.decorator_list)
     function_scope = anno.getanno(node, anno.Static.SCOPE)
     node.body = self._process_block(function_scope, node.body)
     return node
Beispiel #34
0
def from_str(qn_str):
    """Parses `qn_str` as an expression and returns its resolved QN annotation."""
    resolved = resolve(parser.parse_expression(qn_str))
    return anno.getanno(resolved, anno.Basic.QN)
Beispiel #35
0
    def visit_Call(self, node):
        """Rewrites a call into `ag__.converted_call(func, options, args, kwargs)`.

        The call's children are always visited first. The (visited) call is
        returned unchanged for: calls into the `ag__` module, calls on the
        function context manager, `pdb.set_trace` / `ipdb.set_trace` /
        `breakpoint`, and `print` when the BUILTIN_FUNCTIONS feature is not
        enabled.

        Args:
          node: the call AST node. `node.func` may carry an `anno.Basic.QN`
            annotation, which is used to recognize the callee by name.

        Returns:
          Either the visited original call node, or a new expression that calls
          `ag__.converted_call`.
        """
        # Read the callee's qualified name before visiting; the visit below may
        # rewrite `node.func` (e.g. when the callee is itself a call).
        full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
        function_context_name = self.state[_Function].context_name
        node = self.generic_visit(node)

        # TODO(mdan): Refactor converted_call as a 'Call' operator.

        # Calls to the internal 'ag__' module are never converted (though their
        # arguments might be).
        if full_name.startswith('ag__.'):
            return node

        # Calls to the function context manager (inserted by function_scopes) are
        # also safe.
        if full_name.startswith(function_context_name + '.'):
            return node

        # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use
        # the normal mechanisms to bypass these literals because they are sensitive
        # to the frame they are being called from.
        # TODO(mdan): Generalize this to a "static whitelist" config.
        if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
            # Module-level flag: the warning below is logged at most once.
            global set_trace_warned
            if not set_trace_warned:
                # TODO(mdan): Update and shorten once available on tensorflow.org.
                ag_logging.warn(
                    'Detected `pdb.set_trace()` in converted code. The code'
                    ' generated by AutoGraph is not optimized for step-by-step'
                    ' debugging. See https://github.com/tensorflow/tensorflow/'
                    'blob/master/tensorflow/python/autograph/g3doc/reference/'
                    'debugging.md.')
                set_trace_warned = True
            return node

        # `print` is only converted when the BUILTIN_FUNCTIONS feature is on.
        if (full_name == 'print' and not self.ctx.program.options.uses(
                converter.Feature.BUILTIN_FUNCTIONS)):
            return node

        func = node.func

        # Split positional arguments into plain ones and at most one `*args`.
        starred_arg = None
        normal_args = []
        for a in node.args:
            if isinstance(a, gast.Starred):
                assert starred_arg is None, 'Multiple *args should be impossible.'
                starred_arg = a
            else:
                normal_args.append(a)
        # Positional args become a tuple expression, extended with the `*args`
        # value when present.
        if starred_arg is None:
            args = templates.replace_as_expression('(args,)', args=normal_args)
        else:
            args = templates.replace_as_expression('(args,) + tuple(stararg)',
                                                   stararg=starred_arg.value,
                                                   args=normal_args)

        # Split keywords into named ones and at most one `**kwargs` (arg is None).
        kwargs_arg = None
        normal_keywords = []
        for k in node.keywords:
            if k.arg is None:
                assert kwargs_arg is None, 'Multiple **kwargs should be impossible.'
                kwargs_arg = k
            else:
                normal_keywords.append(k)
        # Keyword args become `None`, a dict of the named keywords, or
        # `dict(<**kwargs value>, **<named keywords>)`.
        if kwargs_arg is None:
            if not normal_keywords:
                kwargs = parser.parse_expression('None')
            else:
                kwargs = ast_util.keywords_to_dict(normal_keywords)
        else:
            kwargs = templates.replace_as_expression(
                'dict(kwargs, **keywords)',
                kwargs=kwargs_arg.value,
                keywords=ast_util.keywords_to_dict(normal_keywords))

        template = """
      ag__.converted_call(func, options, args, kwargs)
    """
        # Call options are read from the `callopts` attribute of the function
        # context object.
        new_call = templates.replace_as_expression(
            template,
            func=func,
            options=parser.parse_expression(function_context_name +
                                            '.callopts'),
            args=args,
            kwargs=kwargs)

        return new_call
Beispiel #36
0
 def assertDifferentAnno(self, first, second, key):
   """Asserts that `first` and `second` hold distinct objects under `key`."""
   first_value = anno.getanno(first, key)
   second_value = anno.getanno(second, key)
   self.assertIsNot(first_value, second_value)
Beispiel #37
0
    def visit_For(self, node):
        """Lowers a `for` statement into a call to `ag__.for_stmt`.

        The loop body and the optional `extra_test` annotation are renamed with
        the symbol map produced by `_state_constructs`. They are then wrapped
        into a test function and a body function that are passed to
        `ag__.for_stmt` together with the iterated expression.

        Args:
          node: the `for` AST node.

        Returns:
          The statements from `_create_undefined_assigns` followed by the
          statements generated from the loop template.
        """
        self.generic_visit(node)

        loop_state, reserved_symbols, possibly_undef = self._get_loop_state(
            node)
        loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
            loop_state, reserved_symbols)
        renamed_body = ast_util.rename_symbols(node.body, ssf_map)

        if not anno.hasanno(node, 'extra_test'):
            test_expr = parser.parse_expression('True')
        else:
            test_expr = ast_util.rename_symbols(
                anno.getanno(node, 'extra_test'), ssf_map)

        # Replacements common to both templates. Keyword arguments are evaluated
        # left to right, so 'extra_test' is requested from the namer before
        # 'loop_body'.
        replacements = dict(
            iter_=node.iter,
            iterate=node.target,
            extra_test_name=self.ctx.namer.new_symbol('extra_test',
                                                      reserved_symbols),
            extra_test_expr=test_expr,
            body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
            body=renamed_body)

        if loop_state:
            # Stateful loop: the state is threaded through both loop functions.
            template = """
        def extra_test_name(state_ssf):
          return extra_test_expr
        def body_name(loop_vars, state_ssf):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return state_ssf,
        state_ast_tuple = ag__.for_stmt(
            iter_, extra_test_name, body_name, (state,))
      """
            replacements.update(state=loop_state,
                                state_ssf=state_ssf,
                                state_ast_tuple=state_ast_tuple)
        else:
            template = """
        def extra_test_name():
          return extra_test_expr
        def body_name(loop_vars):
          # Workaround for PEP-3113
          iterate = loop_vars
          body
          return ()
        ag__.for_stmt(iter_, extra_test_name, body_name, ())
      """
        lowered = templates.replace(template, **replacements)

        return self._create_undefined_assigns(possibly_undef) + lowered
 def assertQNStringIs(self, node, qn_str):
     """Asserts that the string form of `node`'s QN annotation equals `qn_str`."""
     actual_qn = anno.getanno(node, anno.Basic.QN)
     self.assertEqual(str(actual_qn), qn_str)
Beispiel #39
0
 def res_call(self, ns, types_ns, node, f_type, args, keywords):
     """Resolver hook used by the test: checks the call to `g`, reports float.

     Uses the enclosing test's `test_self` to check that the callee type is
     `(str,)` and that the callee's QN is `g`.
     Returns `({float}, None)`.
     """
     test_self.assertEqual(f_type, (str, ))
     callee_qn = anno.getanno(node.func, anno.Basic.QN)
     test_self.assertEqual(callee_qn, qual_names.QN('g'))
     result_types = {float}
     return result_types, None
 def assertNotSameDef(self, first, second):
     """Asserts both nodes have exactly one definition and the two differ."""
     self.assertHasDefs(first, 1)
     self.assertHasDefs(second, 1)
     first_def = anno.getanno(first, anno.Static.DEFINITIONS)[0]
     second_def = anno.getanno(second, anno.Static.DEFINITIONS)[0]
     self.assertIsNot(first_def, second_def)
 def assertHasDefs(self, node, num):
     """Asserts `node` has `num` DEFINITIONS, each a reaching Definition."""
     definitions = anno.getanno(node, anno.Static.DEFINITIONS)
     self.assertEqual(len(definitions), num)
     for definition in definitions:
         self.assertIsInstance(definition, reaching_definitions.Definition)
 def visit_With(self, node):
     """Visits a `with` statement: its context items, then its scoped body."""
     node.items = self.visit_block(node.items)
     with_body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
     node.body = self._process_block(with_body_scope, node.body)
     return node
 def assertClosureTypes(self, node, expected):
     """Asserts `node`'s closure types, keyed by symbol string, equal `expected`."""
     raw_types = anno.getanno(node, anno.Static.CLOSURE_TYPES)
     by_name = {}
     for symbol, types in raw_types.items():
         by_name[str(symbol)] = types
     self.assertDictEqual(by_name, expected)
Beispiel #44
0
    def visit_For(self, node):
        """Lowers a `for` statement into a call to `ag__.for_stmt`.

        Symbols modified by the body or the iterate are split by
        `_get_loop_vars` into basic loop variables, which are passed through
        the generated body function, and composite ones, which are exposed
        through generated state getter/setter functions. An optional
        `extra_test` annotation becomes a separate test function.

        Args:
          node: the `for` AST node; its BODY_SCOPE and ITERATE_SCOPE
            annotations are read.

        Returns:
          The statements generated from the loop template.
        """
        node = self.generic_visit(node)

        body_modified = anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
        iterate_modified = anno.getanno(
            node, annos.NodeAnno.ITERATE_SCOPE).modified
        (basic_loop_vars, composite_loop_vars, reserved_symbols,
         possibly_undefs) = self._get_loop_vars(
             node, body_modified | iterate_modified)
        loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
            basic_loop_vars)

        new_symbol = self.ctx.namer.new_symbol
        body_name = new_symbol('loop_body', reserved_symbols)
        state_getter_name = new_symbol('get_state', reserved_symbols)
        state_setter_name = new_symbol('set_state', reserved_symbols)
        state_functions = self._create_state_functions(
            composite_loop_vars, state_getter_name, state_setter_name)

        if not anno.hasanno(node, 'extra_test'):
            extra_test_name = parser.parse_expression('None')
            extra_test_function = []
        else:
            extra_test = anno.getanno(node, 'extra_test')
            extra_test_name = new_symbol('extra_test', reserved_symbols)
            test_template = """
        def extra_test_name(loop_vars):
          return extra_test_expr
      """
            extra_test_function = templates.replace(
                test_template,
                extra_test_name=extra_test_name,
                loop_vars=loop_vars,
                extra_test_expr=extra_test)

        # Workaround for PEP-3113: the body function receives the iterates as a
        # single variable (possibly a tuple), which is then unpacked into the
        # original loop target.
        iterates_var_name = new_symbol('iterates', reserved_symbols)
        expansion_template = """
      iterates = iterates_var_name
    """
        iterate_expansion = templates.replace(
            expansion_template,
            iterates=node.target,
            iterates_var_name=iterates_var_name)

        undefined_assigns = self._create_undefined_assigns(possibly_undefs)

        basic_symbol_names = tuple(
            gast.Str(str(sym)) for sym in basic_loop_vars)
        composite_symbol_names = tuple(
            gast.Str(str(sym)) for sym in composite_loop_vars)

        opts = self._create_loop_options(node)

        # Replacements shared by both templates.
        replacements = dict(
            undefined_assigns=undefined_assigns,
            iter_=node.iter,
            iterate_expansion=iterate_expansion,
            iterates_var_name=iterates_var_name,
            extra_test_name=extra_test_name,
            extra_test_function=extra_test_function,
            body_name=body_name,
            body=node.body,
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            composite_symbol_names=composite_symbol_names,
            opts=opts)

        # TODO(mdan): Use a single template.
        # If the body and test functions took a single tuple for loop_vars, instead
        # of *loop_vars, then a single template could be used.
        if not loop_vars:
            template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name):
          iterate_expansion
          body
          return ()
        extra_test_function
        ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
            return templates.replace(template, **replacements)

        template = """
        undefined_assigns
        state_functions
        def body_name(iterates_var_name, loop_vars):
          iterate_expansion
          body
          return loop_vars,
        extra_test_function
        loop_vars_ast_tuple = ag__.for_stmt(
            iter_,
            extra_test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
        return templates.replace(
            template,
            loop_vars=loop_vars,
            loop_vars_ast_tuple=loop_vars_ast_tuple,
            basic_symbol_names=basic_symbol_names,
            **replacements)
Beispiel #45
0
 def assertSameAnno(self, first, second, key):
   """Asserts that `first` and `second` hold the very same object under `key`."""
   first_value = anno.getanno(first, key)
   second_value = anno.getanno(second, key)
   self.assertIs(first_value, second_value)
Beispiel #46
0
    def visit_Call(self, node):
        """Rewrites a call into `ag__.converted_call(func, owner, options, args, kwargs)`.

        For attribute calls `obj.m(...)`, `func` becomes the string `'m'` and
        `owner` becomes `obj`. For any other callee, `func` is the callee
        expression and `owner` is `None`. Calls into the `ag__` module, and
        `print` when BUILTIN_FUNCTIONS is not enabled, are only generically
        visited and otherwise left unchanged.

        NOTE(review): only non-starred positional args and named keywords are
        visited below. The `*args` value, the `**kwargs` value and the
        callee/owner expressions are passed through without visiting, so calls
        nested inside them are not converted. Confirm this is intended.

        Args:
          node: the call AST node.

        Returns:
          The generically visited call, or a new `ag__.converted_call`
          expression.
        """
        # TODO(mdan): Refactor converted_call as a 'Call' operator.

        # Calls to the internal 'ag__' module are never converted (though their
        # arguments might be).
        full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
        if full_name.startswith('ag__.'):
            return self.generic_visit(node)
        if (full_name == 'print' and not self.ctx.program.options.uses(
                converter.Feature.BUILTIN_FUNCTIONS)):
            return self.generic_visit(node)

        # Method calls are passed as (method name, receiver) rather than as a
        # bound attribute expression.
        if isinstance(node.func, gast.Attribute):
            func = gast.Str(node.func.attr)
            owner = node.func.value
        else:
            func = node.func
            owner = parser.parse_expression('None')

        # Split positional arguments into plain ones (visited) and at most one
        # `*args` entry (not visited).
        starred_arg = None
        normal_args = []
        for a in node.args:
            if isinstance(a, gast.Starred):
                assert starred_arg is None, 'Multiple *args should be impossible.'
                starred_arg = a
            else:
                a = self.visit(a)
                normal_args.append(a)
        if starred_arg is None:
            args = templates.replace_as_expression('(args,)', args=normal_args)
        else:
            args = templates.replace_as_expression('(args,) + tuple(stararg)',
                                                   stararg=starred_arg.value,
                                                   args=normal_args)

        # Split keywords into named ones (visited) and at most one `**kwargs`
        # entry, identified by `arg is None` (not visited).
        kwargs_arg = None
        normal_keywords = []
        for k in node.keywords:
            if k.arg is None:
                assert kwargs_arg is None, 'Multiple **kwargs should be impossible.'
                kwargs_arg = k
            else:
                k = self.visit(k)
                normal_keywords.append(k)
        # Keyword args become `None`, a dict of the named keywords, or
        # `dict(<**kwargs value>, **<named keywords>)`.
        if kwargs_arg is None:
            if not normal_keywords:
                kwargs = parser.parse_expression('None')
            else:
                kwargs = ast_util.keywords_to_dict(normal_keywords)
        else:
            kwargs = templates.replace_as_expression(
                'dict(kwargs, **keywords)',
                kwargs=kwargs_arg.value,
                keywords=ast_util.keywords_to_dict(normal_keywords))

        template = """
      ag__.converted_call(func, owner, options, args, kwargs)
    """
        # The options AST is built from the program options, with
        # `internal_convert_user_code` set from the `recursive` option.
        new_call = templates.replace_as_expression(
            template,
            func=func,
            owner=owner,
            options=self.ctx.program.options.to_ast(
                internal_convert_user_code=self.ctx.program.options.recursive),
            args=args,
            kwargs=kwargs)

        return new_call
def _live_tensors(f, attr_name="inputs"):
    """Returns the indices of the used inputs.

    Note: This currently only handles direct index accesses e.g. op.inputs[1].
    If the function has slicing or list comprehension on attr_name then returns
    _ALL. This ensures that the result is correct even if inefficient.

    Args:
      f: A grad function, taking the op as first argument.
      attr_name: op attr to track. "inputs" or "outputs".

    Returns:
      Either one of:
        * set of integers representing individual indices of inputs used
        * the value _ALL, if indices are used but cannot be determined which
        * empty set, if no inputs are used
    """
    node, _ = parser.parse_entity(f, ())
    entity_info = transformer.EntityInfo(
        name=f.__name__,
        source_code=None,
        source_file=None,
        future_features=(),
        namespace=sys.modules[f.__module__].__dict__)
    ctx = transformer.Context(entity_info, None, None)

    # Annotate the AST with the static analyses the trackers below consume:
    # qualified names, activity, reaching function definitions and liveness.
    graphs = cfg.build(node)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx, None)
    node = reaching_fndefs.resolve(node, ctx, graphs)
    node = liveness.resolve(node, ctx, graphs)

    # Qualified name of the tracked attribute of the first argument, e.g.
    # `op.inputs`.
    op_arg_name = anno.getanno(node.args.args[0], anno.Basic.QN)
    op_inputs_outputs_name = qual_names.QN(op_arg_name, attr=attr_name)

    special_tracker = _SubscriptUseTracker(ctx, (op_inputs_outputs_name, ))
    node = special_tracker.visit(node)

    live_vars_in = anno.getanno(node.body[0], anno.Static.LIVE_VARS_IN)
    inputs_outputs_used_qns = set()
    for v in special_tracker.complex_reads:
        # Complicated patterns like op.inputs[:3]. Could be smarter about them
        # if they matter much.
        if v == op_inputs_outputs_name:
            return _ALL
    for v in live_vars_in:
        if v in special_tracker.reads:
            if (v.has_subscript() and v.parent == op_inputs_outputs_name):
                inputs_outputs_used_qns.add(v)
            elif v == op_inputs_outputs_name:
                # When op.{attr_name} is used directly, assume all tensors are
                # used for now. In that case, no point digging further.
                # TODO(mdan): We can descend into tuple expansions.
                return _ALL

    function_calls_tracker = _FunctionCallsTracker(ctx, op_arg_name)
    node = function_calls_tracker.visit(node)

    input_output_indices = set()

    # Indices used by functions that receive the op are used by `f` too.
    for called_f in function_calls_tracker.calls:
        child_indices = _live_tensors(called_f, attr_name=attr_name)
        if child_indices is _ALL:
            return _ALL
        input_output_indices |= child_indices

    for v in inputs_outputs_used_qns:
        assert v.has_subscript()
        _, subscript = v.qn
        if not subscript.is_simple():
            # Not a number, assuming it can be anything.
            return _ALL
        subscript_val, = subscript.qn
        # Only an integer literal (e.g. `op.inputs[1]`) pins down a specific
        # index, so both conditions must hold. Checking the Literal type first
        # also guarantees `.value` is never read from a non-Literal subscript.
        if not (isinstance(subscript_val, qual_names.Literal)
                and isinstance(subscript_val.value, int)):
            # Not an integer literal, assuming it can be anything.
            return _ALL
        input_output_indices.add(subscript_val.value)
    return input_output_indices
Beispiel #48
0
  def visit_While(self, node):
    """Lowers a `while` statement into a call to `ag__.while_stmt`.

    The loop test and body are renamed with the symbol map produced by
    `_state_constructs` and wrapped into a test function and a body function.
    Symbols supporting the reads in the loop condition are passed as extra
    dependencies.

    Args:
      node: the `while` AST node; its COND_SCOPE annotation is read.

    Returns:
      The statements from `_create_undefined_assigns` followed by the
      statements generated from the loop template.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)

    # Note: one might expect we can dispatch based on the loop condition.
    # But because that is dependent on the state, it cannot be evaluated ahead
    # of time - doing that would risk duplicating any effects the condition has.
    # Furthermore, we cannot evaluate slices and attributes, because they might
    # trigger __getitem__ or __getattribute__.
    #
    # A case where this fails includes ops with side effects on a stateful
    # resource captured in an object:
    #
    #   while self.v.read() > 0:
    #     self.v.assign(1)
    #
    # TODO(mdan): Handle the case above.
    cond_reads = anno.getanno(node, annos.NodeAnno.COND_SCOPE).read
    cond_closure = set()
    for read_symbol in cond_reads:
      cond_closure |= read_symbol.support_set

    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    renamed_body = ast_util.rename_symbols(node.body, ssf_map)
    renamed_test = ast_util.rename_symbols(node.test, ssf_map)

    # Replacements common to both templates; 'loop_test' is requested from the
    # namer before 'loop_body'.
    replacements = dict(
        test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
        test=renamed_test,
        body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
        body=renamed_body,
        extra_deps=tuple(dep.ast() for dep in cond_closure),
    )

    if not loop_state:
      template = """
        def test_name():
          return test
        def body_name():
          body
          return ()
        ag__.while_stmt(test_name, body_name, (), (extra_deps,))
      """
    else:
      # Stateful loop: the state is threaded through both loop functions.
      template = """
        def test_name(state_ssf):
          return test
        def body_name(state_ssf):
          body
          return state_ssf,
        state_ast_tuple = ag__.while_stmt(
            test_name, body_name, (state,), (extra_deps,))
      """
      replacements.update(
          state=loop_state,
          state_ssf=state_ssf,
          state_ast_tuple=state_ast_tuple,
      )
    lowered = templates.replace(template, **replacements)

    return self._create_undefined_assigns(possibly_undefs) + lowered
Beispiel #49
0
    def visit_While(self, node):
        """Lowers a `while` statement into a call to `ag__.while_stmt`.

        Symbols modified in the loop body are split by `_get_loop_vars` into
        basic loop variables, which are threaded through the generated test and
        body functions, and composite ones, which are exposed through generated
        state getter/setter functions.

        Args:
          node: the `while` AST node; its BODY_SCOPE annotation is read.

        Returns:
          The statements from `_create_undefined_assigns` followed by the
          statements generated from the loop template.
        """
        node = self.generic_visit(node)

        body_modified = anno.getanno(node, annos.NodeAnno.BODY_SCOPE).modified
        (basic_loop_vars, composite_loop_vars, reserved_symbols,
         possibly_undefs) = self._get_loop_vars(node, body_modified)
        loop_vars, loop_vars_ast_tuple = self._loop_var_constructs(
            basic_loop_vars)

        namer = self.ctx.namer
        state_getter_name = namer.new_symbol('get_state', reserved_symbols)
        state_setter_name = namer.new_symbol('set_state', reserved_symbols)
        state_functions = self._create_state_functions(
            composite_loop_vars, state_getter_name, state_setter_name)

        basic_symbol_names = tuple(
            gast.Str(str(sym)) for sym in basic_loop_vars)
        composite_symbol_names = tuple(
            gast.Str(str(sym)) for sym in composite_loop_vars)

        opts = self._create_loop_options(node)

        # Replacements shared by both templates; 'loop_test' is requested from
        # the namer before 'loop_body'.
        replacements = dict(
            test_name=namer.new_symbol('loop_test', reserved_symbols),
            test=node.test,
            body_name=namer.new_symbol('loop_body', reserved_symbols),
            body=node.body,
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            composite_symbol_names=composite_symbol_names,
            opts=opts)

        # TODO(mdan): Use a single template.
        # If the body and test functions took a single tuple for loop_vars, instead
        # of *loop_vars, then a single template could be used.
        if loop_vars:
            template = """
        state_functions
        def body_name(loop_vars):
          body
          return loop_vars,
        def test_name(loop_vars):
          return test
        loop_vars_ast_tuple = ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (loop_vars,),
            (basic_symbol_names,),
            (composite_symbol_names,),
            opts)
      """
            replacements.update(loop_vars=loop_vars,
                                loop_vars_ast_tuple=loop_vars_ast_tuple,
                                basic_symbol_names=basic_symbol_names)
        else:
            template = """
        state_functions
        def body_name():
          body
          return ()
        def test_name():
          return test
        ag__.while_stmt(
            test_name,
            body_name,
            state_getter_name,
            state_setter_name,
            (),
            (),
            (composite_symbol_names,),
            opts)
      """
        lowered = templates.replace(template, **replacements)

        undefined_assigns = self._create_undefined_assigns(possibly_undefs)
        return undefined_assigns + lowered
Beispiel #50
0
  def visit_If(self, node):
    """Converts an `if` statement into a functional conditional.

    The statement is replaced by the concatenation of:
      * assignments for symbols created in only one branch,
      * an assignment of the test expression to a fresh `cond` variable,
      * state getter/setter functions for composite symbols modified in either
        branch,
      * one function per branch (`if_true` / `if_false`),
      * the expression built by `_create_cond_expr`.

    Args:
      node: the `if` AST node, carrying BODY_SCOPE, ORELSE_SCOPE,
        DEFINED_VARS_IN and LIVE_VARS_OUT annotations.

    Returns:
      The concatenated sequence of generated statements.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in, node.body)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in, node.orelse)

    node = self.generic_visit(node)

    # Non-composite symbols modified in either branch and live after the `if`
    # become outputs of the conditional.
    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    composites = set()
    for s in modified_in_cond:
      if s in live_out and not s.is_composite():
        returned_from_cond.add(s)
      if s.is_composite():
        # Special treatment for compound objects, always return them.
        # This allows special handling within the if_stmt itself.
        # For example, in TensorFlow we need to restore the state of composite
        # symbols to ensure that only effects from the executed branch are seen.
        composites.add(s)

    # Outputs first assigned inside a branch (not defined before the `if`).
    # `-` binds tighter than `&`: modified & (returned - defined_in).
    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    basic_created_in_body = tuple(
        s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(
        s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^
                          set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    # Fresh names for the generated helpers; the namer is called in a fixed
    # order, so these statements should not be reordered.
    cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
    all_referenced = body_scope.referenced | orelse_scope.referenced
    state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
    state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)

    returned_from_cond = tuple(returned_from_cond)
    if returned_from_cond:
      # A single output is used directly; several are packed into a tuple.
      if len(returned_from_cond) == 1:
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      # Each branch returns the outputs under its own aliased names, if any.
      returned_from_body = tuple(
          alias_body_map[s] if s in need_alias_in_body else s
          for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map[s] if s in need_alias_in_orelse else s
          for s in returned_from_cond)

    else:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): Replace with None once side_effect_guards is retired.
      returned_from_body = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)
      returned_from_orelse = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)
    composite_defs = self._create_state_functions(
        composites, state_getter_name, state_setter_name)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name, state_getter_name,
                                       state_setter_name)

    return (undefined_assigns + cond_assign + composite_defs + body_def +
            orelse_def + cond_expr)
Beispiel #51
0
def create_source_map(nodes, code, filepath):
    """Creates a source map between an annotated AST and the code it compiles to.

    Note: this function assumes nodes, code and filepath correspond to the
    same code.

    `code` is reparsed and its nodes are resolved against `filepath`. The
    original and reparsed ASTs are then walked in parallel. Every pair in which
    both nodes carry an origin annotation contributes an entry keyed by the
    reparsed node's line.

    Args:
      nodes: Iterable[ast.AST, ...], one or more AST nodes.
      code: Text, the source code in which nodes are found.
      filepath: Text, passed to `resolve` when annotating the reparsed nodes.

    Returns:
      Dict[LineLocation, OriginInfo], mapping locations in code to locations
      indicated by origin annotations in node.

    Raises:
      ValueError: if walking `nodes` and the reparsed ASTs in parallel fails
        (presumably because their structures differ). The message includes
        the original error and a diff of the two ASTs.
    """
    reparsed_nodes = parser.parse(code, preamble_len=0, single_node=False)
    for node in reparsed_nodes:
        resolve(node, code, filepath, node.lineno, node.col_offset)

    source_map = {}

    try:
        for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
            # Note: generated code might not be mapped back to its origin.
            # TODO(mdan): Generated code should always be mapped to something.
            origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
            final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
            if origin_info is None or final_info is None:
                continue

            # Note: the keys are by line only, excluding the column offset.
            line_loc = LineLocation(final_info.loc.filename,
                                    final_info.loc.lineno)

            existing_origin = source_map.get(line_loc)
            if existing_origin is not None:
                # Overlaps may exist because of child nodes, but almost never to
                # different line locations. The exception is decorated
                # functions, where both lines are mapped to the same line in
                # the AST.

                # Line overlaps: keep bottom node.
                if existing_origin.loc.line_loc == origin_info.loc.line_loc:
                    if existing_origin.loc.lineno >= origin_info.loc.lineno:
                        continue

                # In case of column overlaps, keep the leftmost node.
                if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
                    continue

            source_map[line_loc] = origin_info

    except ValueError as err:
        new_msg = 'Inconsistent ASTs detected. This is a bug. Cause: \n'
        new_msg += str(err)
        # Start the diff on its own line; previously it was glued to the cause.
        new_msg += '\nDiff:\n'

        for n, rn in zip(nodes, reparsed_nodes):
            nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
            reparsed_nodes_str = pretty_printer.fmt(rn,
                                                    color=False,
                                                    noanno=True)
            diff = difflib.context_diff(nodes_str.split('\n'),
                                        reparsed_nodes_str.split('\n'),
                                        fromfile='Original nodes',
                                        tofile='Reparsed nodes',
                                        n=7)
            diff = '\n'.join(diff)
            new_msg += diff + '\n'
        raise ValueError(new_msg)

    return source_map
Beispiel #52
0
 def visit_FunctionDef(self, node):
     """Visits a function inside a fresh `_Function` state.

     The function's BODY_SCOPE annotation is stored on the state before the
     children are visited.

     Args:
       node: the function definition node.

     Returns:
       The result of `self.generic_visit(node)`.
     """
     with self.state[_Function] as function_state:
         body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
         function_state.scope = body_scope
         return self.generic_visit(node)
Beispiel #53
0
  def visit(self, node):
    """Visits `node` with origin tracking and consistency checks.

    Wraps the parent visitor's `visit` with bookkeeping:
      * function, class and lambda nodes are pushed on
        `self._enclosing_entities` for the duration of the visit;
      * `self.ctx.current_origin` is set from the node's ORIGIN annotation
        (if any) and restored after a successful visit;
      * an `Expr` node whose value was replaced by a list, tuple, `Assign` or
        `AugAssign` is unwrapped to that value;
      * the local scope stack must have the same length on exit as on entry.

    Nodes annotated with `anno.Basic.SKIP_PROCESSING` are not visited and are
    returned unchanged.

    Args:
      node: gast.AST, the node to visit.

    Returns:
      Whatever the parent visitor returned for `node` (after the Expr
      unwrapping above), or `node` itself when it is marked to be skipped.

    Raises:
      ValueError: if `node` is not a `gast.AST`, e.g. a list of statements.
      AssertionError: if the local scope stack is unbalanced after the visit.
      AutoGraphParseError: wraps a ValueError, AttributeError, KeyError or
        NotImplementedError raised during the visit when an origin is known.
        Without an origin, the original exception propagates.
    """
    if not isinstance(node, gast.AST):
      # This is not that uncommon a mistake: various node bodies are lists, for
      # example, posing a land mine for transformers that need to recursively
      # call `visit`.  The error needs to be raised before the exception handler
      # below is installed, because said handler will mess up if `node` is not,
      # in fact, a node.
      msg = ('invalid value for "node": expected "ast.AST", got "{}"; to'
             ' visit lists of nodes, use "visit_block" instead').format(
                 type(node))
      raise ValueError(msg)

    did_enter_function = False
    local_scope_size_at_entry = len(self._local_scope_state)
    processing_expr_node = False

    parent_origin = self.ctx.current_origin
    try:
      if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
        did_enter_function = True
      elif isinstance(node, gast.Expr):
        processing_expr_node = True

      if did_enter_function:
        self._enclosing_entities.append(node)

      if anno.hasanno(node, anno.Basic.ORIGIN):
        self.ctx.current_origin = anno.getanno(node, anno.Basic.ORIGIN)

      if processing_expr_node:
        entry_expr_value = node.value

      if anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
        # Skipped nodes pass through unchanged. Without this branch `result`
        # would be unbound below and raise UnboundLocalError.
        result = node
      else:
        result = super(Base, self).visit(node)
      # The origin is restored only on success: on failure the handler below
      # needs the innermost origin to report where the error occurred.
      self.ctx.current_origin = parent_origin

      # Adjust for consistency: replacing the value of an Expr with
      # an Assign node removes the need for the Expr node.
      if processing_expr_node:
        if isinstance(result, gast.Expr) and result.value != entry_expr_value:
          # When the replacement is a list, it is assumed that the list came
          # from a template that contained a number of statements, which
          # themselves are standalone and don't require an enclosing Expr.
          if isinstance(result.value,
                        (list, tuple, gast.Assign, gast.AugAssign)):
            result = result.value

      # On exception, the local scope integrity is not guaranteed.
      if did_enter_function:
        self._enclosing_entities.pop()

      if local_scope_size_at_entry != len(self._local_scope_state):
        raise AssertionError(
            'Inconsistent local scope stack. Before entering node %s, the'
            ' stack had length %d, after exit it has length %d. This'
            ' indicates enter_local_scope and exit_local_scope are not'
            ' well paired.' % (node, local_scope_size_at_entry,
                               len(self._local_scope_state)))
      return result

    except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
      if not self.ctx.current_origin:
        # Bare raise keeps the original traceback intact.
        raise
      original_file_path = self.ctx.current_origin.loc.filename
      original_line_number = self.ctx.current_origin.loc.lineno
      original_col_offset = self.ctx.current_origin.loc.col_offset
      original_source_line = (
          self.ctx.current_origin.source_code_line)
      msg = '%s: %s.' % (e.__class__.__name__, str(e))

      # TODO(mdan): Avoid the printing of the original exception.
      # In other words, we need to find how to suppress the "During handling
      # of the above exception, another exception occurred" message.
      six.reraise(
          AutoGraphParseError,
          AutoGraphParseError(msg, (original_file_path, original_line_number,
                                    original_col_offset, original_source_line)),
          sys.exc_info()[2])
Beispiel #54
0
    def visit_For(self, node):
        """Converts a `for` statement into a call to `ag__.for_stmt`.

        Children are visited first. The loop variables are derived from the
        symbols modified in the loop body and iterate scopes. The result is the
        statement sequence produced by the last template in this method:
          * the state getter/setter functions;
          * a body function that takes a single iterate argument and unpacks
            it into the original loop target;
          * an optional extra-test function;
          * assignments for possibly undefined symbols;
          * the `ag__.for_stmt` call itself.

        Args:
          node: gast.For, the loop to convert.

        Returns:
          The value returned by `templates.replace` for that template
          (presumably a list of gast statements; confirm against
          `templates.replace`).
        """
        node = self.generic_visit(node)
        body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
        iter_scope = anno.getanno(node, annos.NodeAnno.ITERATE_SCOPE)

        # Loop variables come from everything modified in the body or the
        # iterate (target) scope.
        loop_vars, reserved_symbols, possibly_undefs = self._get_loop_vars(
            node, body_scope.modified | iter_scope.modified)

        undefined_assigns = self._create_undefined_assigns(possibly_undefs)

        nonlocal_declarations = self._create_nonlocal_declarations(loop_vars)

        # NOTE(review): the generated names likely depend on the order of the
        # namer.new_symbol calls in this method (get_state, set_state,
        # extra_test, itr, loop_body); keep that order stable.
        state_getter_name = self.ctx.namer.new_symbol('get_state',
                                                      reserved_symbols)
        state_setter_name = self.ctx.namer.new_symbol('set_state',
                                                      reserved_symbols)
        state_functions = self._create_state_functions(loop_vars,
                                                       nonlocal_declarations,
                                                       state_getter_name,
                                                       state_setter_name)

        # Record the loop target's source text under the 'iterate_names' key
        # of the loop options.
        opts = self._create_loop_options(node)
        opts.keys.append(gast.Constant('iterate_names', kind=None))
        opts.values.append(
            gast.Constant(parser.unparse(node.target,
                                         include_encoding_marker=False),
                          kind=None))

        # An EXTRA_LOOP_TEST annotation becomes a zero-argument function
        # returning that test expression. Otherwise the literal `None` is
        # passed to for_stmt and no function is emitted.
        if anno.hasanno(node, anno.Basic.EXTRA_LOOP_TEST):
            extra_test = anno.getanno(node, anno.Basic.EXTRA_LOOP_TEST)
            extra_test_name = self.ctx.namer.new_symbol(
                'extra_test', reserved_symbols)
            template = """
        def extra_test_name():
          nonlocal_declarations
          return extra_test_expr
      """
            # NOTE(review): `loop_vars` is not a placeholder in this template;
            # presumably templates.replace ignores it. Confirm before removing.
            extra_test_function = templates.replace(
                template,
                extra_test_expr=extra_test,
                extra_test_name=extra_test_name,
                loop_vars=loop_vars,
                nonlocal_declarations=nonlocal_declarations)
        else:
            extra_test_name = parser.parse_expression('None')
            extra_test_function = []

        # iterate_arg_name holds a single arg with the iterates, which may be a
        # tuple.
        iterate_arg_name = self.ctx.namer.new_symbol('itr', reserved_symbols)
        # Assigns the single iterate argument to the original loop target, so
        # the body sees the same names as before.
        template = """
      iterates = iterate_arg_name
    """
        iterate_expansion = templates.replace(
            template, iterate_arg_name=iterate_arg_name, iterates=node.target)

        template = """
      state_functions
      def body_name(iterate_arg_name):
        nonlocal_declarations
        iterate_expansion
        body
      extra_test_function
      undefined_assigns
      ag__.for_stmt(
          iterated,
          extra_test_name,
          body_name,
          state_getter_name,
          state_setter_name,
          (symbol_names,),
          opts)
    """
        return templates.replace(
            template,
            body=node.body,
            body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
            extra_test_function=extra_test_function,
            extra_test_name=extra_test_name,
            iterate_arg_name=iterate_arg_name,
            iterate_expansion=iterate_expansion,
            iterated=node.iter,
            nonlocal_declarations=nonlocal_declarations,
            opts=opts,
            # One string constant per loop variable, spliced into the
            # `(symbol_names,)` tuple of the template.
            symbol_names=tuple(
                gast.Constant(str(s), kind=None) for s in loop_vars),
            state_functions=state_functions,
            state_getter_name=state_getter_name,
            state_setter_name=state_setter_name,
            undefined_assigns=undefined_assigns)
Beispiel #55
0
 def visit_Lambda(self, node):
     """Visits a lambda inside a fresh `_Function` state.

     The lambda's `anno.Static.SCOPE` annotation is stored on the state before
     the children are visited.

     Args:
       node: the lambda node.

     Returns:
       The result of `self.generic_visit(node)`.
     """
     with self.state[_Function] as lambda_state:
         lambda_scope = anno.getanno(node, anno.Static.SCOPE)
         lambda_state.scope = lambda_scope
         return self.generic_visit(node)
Beispiel #56
0
 def assertTypes(self, node, expected):
     """Asserts that the TYPES annotation of `node` equals `expected` as a set.

     Args:
       node: the AST node whose `anno.Static.TYPES` annotation is checked.
       expected: a single expected type, or a tuple of expected types.
     """
     actual = set(anno.getanno(node, anno.Static.TYPES))
     if isinstance(expected, tuple):
         wanted = set(expected)
     else:
         wanted = {expected}
     self.assertSetEqual(actual, wanted)
    def visit_Expr(self, node):
        """Wraps a bare function-call statement in a control-dependency guard.

        Only `Expr` nodes whose value is a `gast.Call` are rewritten. Examples
        are `opt.minimize(loss)` or `tf.py_func(...)`. Any other node is
        returned after its children have been visited.

        The call becomes the context expression of a
        `with ag__.utils.control_dependency_on_returns(call):` block. The
        guarded arguments are the call's non-composite names, excluding
        `self` and `tf`.
          * If there are guarded arguments, the block body re-binds them
            through `ag__.utils.alias_tensors`. Those not modified in the
            enclosing scope get freshly generated names.
          * Otherwise the block body is empty.

        The returned node is annotated with `anno.Basic.INDENT_BLOCK_REMAINDER`
        as `(body, alias_map)`. Presumably a later pass uses this to move the
        remaining statements into the block and apply the renames; verify
        against that pass.

        Args:
          node: gast.Expr, the statement to visit.

        Returns:
          The `with` guard node, or `node` if its value is not a call.
        """
        self.generic_visit(node)
        if not isinstance(node.value, gast.Call):
            return node

        # Gate future evaluation of the call's arguments where possible;
        # otherwise all remaining statements are gated (which may fail too,
        # see _visit_and_reindent).
        call_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
        # Composite names (e.g. object attributes) are skipped because they may
        # not be writable; the well-known names below are never renamed either.
        # TODO(mdan): Move these names into config.
        never_guarded = (qual_names.QN('self'), qual_names.QN('tf'))
        guarded_args = tuple(
            name for name in call_scope.used
            if not name.is_composite() and name not in never_guarded)

        if not guarded_args:
            alias_map = {}
            template = """
          with ag__.utils.control_dependency_on_returns(call):
            pass
        """
            guard = templates.replace(template, call=node.value)[-1]
            guard.body = []
        else:
            # TODO(mdan): Also guard arguments that depend on guarded_args. E.g.
            # this still races, since `c`'s control deps should include `b`:
            #   tf.assign(a, a + 1)
            #   b = a + 1
            #   tf.assign(a, a + 1)
            #   c = b + 1
            # Or maybe raise an "unsafe assign" error instead?
            parent_scope = call_scope.parent
            # Names the enclosing scope does not modify get fresh aliases, so
            # re-binding them does not incorrectly make them local.
            # TODO(mdan): This is brutal. It will even rename modules - any fix?
            to_alias = [name for name in guarded_args
                        if name not in parent_scope.modified]
            alias_map = {
                name: qual_names.QN(
                    self.ctx.namer.new_symbol(name.ssf(),
                                              parent_scope.referenced))
                for name in to_alias
            }
            if len(guarded_args) == 1:
                aliased_guarded_args = alias_map.get(guarded_args[0],
                                                     guarded_args[0])
            else:
                aliased_guarded_args = gast.Tuple(
                    [alias_map.get(name, name).ast() for name in guarded_args],
                    None)

            template = """
          with ag__.utils.control_dependency_on_returns(call):
            aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
        """
            guard = templates.replace(
                template,
                call=node.value,
                aliased_guarded_args=aliased_guarded_args,
                guarded_args=guarded_args)[-1]

        anno.setanno(guard, anno.Basic.INDENT_BLOCK_REMAINDER,
                     (guard.body, alias_map))
        return guard
Beispiel #58
0
    def visit_FunctionDef(self, node):
        """Filters a function's decorators and records the imports they need.

        Children are visited first. Then each decorator is dropped or kept:
          * a bare `classmethod` name is dropped;
          * a decorator whose `live_val` is listed in
            `self.ctx.program.options.strip_decorators` is dropped;
          * every other decorator is kept. The object behind the leftmost name
            of its (possibly dotted) expression is recorded, and a matching
            import statement is added to
            `self.ctx.program.additional_imports`.

        Args:
          node: gast.FunctionDef, the function being converted.

        Returns:
          `node`, with `decorator_list` replaced by the kept decorators.

        Raises:
          ValueError: if the leftmost name of a kept decorator has no
            `live_val` annotation, or if a kept non-module object reports
            `__module__ == '__main__'`.
        """
        self.generic_visit(node)
        kept = []
        for decorator in node.decorator_list:
            if isinstance(decorator, gast.Call):
                target = decorator.func
            else:
                target = decorator

            # TODO(mdan): Is there any way we can treat these more generically?
            # We may want to forego using decorators altogether if we can't
            # properly support them.
            if isinstance(target, gast.Name) and target.id == 'classmethod':
                # Assumption: decorators are only visible in the AST when
                # converting a function inline (via another decorator). The
                # converted function is then no longer part of the object it
                # was declared in. This is currently verified by tests.
                continue

            decorator_qn = anno.getanno(target, anno.Basic.QN)
            if (anno.getanno(target, 'live_val') in
                    self.ctx.program.options.strip_decorators):
                continue

            # For foo.bar.baz only `foo` needs to be imported.
            root = target
            while isinstance(root, gast.Attribute):
                root = root.value

            if not anno.hasanno(root, 'live_val'):
                raise ValueError(
                    'could not resolve symbol "%s" when looking up decorator "%s"'
                    % (anno.getanno(root, anno.Basic.QN), decorator_qn))

            # Entry: (decorator AST, object to import, name to bind it to).
            # E.g. for foo.bar: (<ast>, <module foo>, 'foo'); for baz the
            # object is imported from baz.__module__ under the name 'baz'.
            kept.append((decorator, anno.getanno(root, 'live_val'),
                         anno.getanno(root, anno.Basic.QN)))

        for _, support, bound_name in kept:
            if tf_inspect.ismodule(support):
                self.ctx.program.additional_imports.add(
                    'import %s as %s' % (support.__name__, bound_name))
                continue
            if support.__module__ == '__main__':
                raise ValueError(
                    'decorator "%s" was not allowed because it is declared '
                    'in the module "%s". To fix this, declare it in a separate '
                    'module that we can import it from.' %
                    (support, support.__module__))
            self.ctx.program.additional_imports.add(
                'from %s import %s' % (support.__module__, bound_name))

        node.decorator_list = [entry[0] for entry in kept]
        return node