예제 #1
0
def parse_entity(entity, future_features):
  """Returns the AST and source code of given entity.

  Lambdas are handed off to `_parse_lambda`. For any other entity, the source
  is retrieved, dedented, and prefixed with one `from __future__ import ...`
  line per requested feature before being parsed.

  Args:
    entity: Any, Python function/method/class
    future_features: Iterable[Text], future features to use (e.g.
      'print_statement'). See
      https://docs.python.org/2/reference/simple_stmts.html#future
      Any iterable is accepted, including one-shot generators.

  Returns:
    gast.AST, Text: the parsed AST node; the source code that was parsed to
    generate the AST (including any prefixes that this function may have added).

  Raises:
    ValueError: if the source code of `entity` cannot be located.
  """
  if inspect_utils.islambda(entity):
    return _parse_lambda(entity)

  try:
    original_source = inspect_utils.getimmediatesource(entity)
  except (IOError, OSError) as e:
    raise ValueError(
        'Unable to locate the source code of {}. Note that functions defined'
        ' in certain environments, like the interactive Python shell, do not'
        ' expose their source code. If that is the case, you should define'
        ' them in a .py source file. If you are certain the code is'
        ' graph-compatible, wrap the call using'
        ' @tf.autograph.experimental.do_not_convert. Original error: {}'.format(
            entity, e))

  source = dedent_block(original_source)

  future_statements = tuple(
      'from __future__ import {}'.format(name) for name in future_features)
  source = '\n'.join(future_statements + (source,))

  # Count the materialized statements instead of `future_features` itself: the
  # argument is only required to be iterable, so it may not support len() and
  # may already have been exhausted by the tuple() call above.
  return parse(source, preamble_len=len(future_statements)), source
예제 #2
0
  def compiled_function_name(self,
                             original_fqn,
                             live_entity=None,
                             owner_type=None):
    """See call_trees.FunctionNamer.compiled_function_name.

    Returns:
      A `(name, renamed)` tuple. It is `(None, False)` when this namer is not
      recursive, when `live_entity` is a lambda, or when `owner_type` is given
      but is not one of `self.partial_types`. Otherwise it is the selected
      name together with True.
    """
    has_entity = live_entity is not None

    # Cases in which no new name is produced.
    if not self.recursive:
      return None, False
    if has_entity and inspect_utils.islambda(live_entity):
      return None, False
    if not (owner_type is None or owner_type in self.partial_types):
      # Members are not renamed when part of an entire converted class.
      return None, False

    # An entity that was named before keeps the name it was given.
    if has_entity and live_entity in self.renamed_calls:
      return self.renamed_calls[live_entity], True

    base_name = 'tf__%s' % self._as_symbol_name(
        original_fqn, style=_NamingStyle.SNAKE)

    # Append _1, _2, ... until the name is absent from the global namespace.
    candidate = base_name
    suffix = 1
    while candidate in self.global_namespace:
      candidate = '%s_%d' % (base_name, suffix)
      suffix += 1

    if has_entity:
      self.renamed_calls[live_entity] = candidate
    self.generated_names.add(candidate)

    return candidate, True
예제 #3
0
    def compiled_function_name(self,
                               original_fqn,
                               live_entity=None,
                               owner_type=None):
        """See call_trees.FunctionNamer.compiled_function_name.

        The result is a `(name, renamed)` pair: `(None, False)` whenever no
        new name is produced, otherwise the chosen name and True.
        """
        unchanged = (None, False)
        known_entity = live_entity is not None

        if not self.recursive:
            return unchanged

        if known_entity and inspect_utils.islambda(live_entity):
            return unchanged

        if owner_type is not None:
            if owner_type not in self.partial_types:
                # Members are not renamed when part of an entire converted
                # class.
                return unchanged

        # Entities that already received a name get the same one again.
        if known_entity and live_entity in self.renamed_calls:
            return self.renamed_calls[live_entity], True

        symbol = self._as_symbol_name(original_fqn, style=_NamingStyle.SNAKE)
        root = 'tf__%s' % symbol

        # Try the bare root first, then root_1, root_2, ... until the name is
        # not present in the global namespace.
        chosen, attempt = root, 0
        while chosen in self.global_namespace:
            attempt += 1
            chosen = '%s_%d' % (root, attempt)

        if known_entity:
            self.renamed_calls[live_entity] = chosen
        self.generated_names.add(chosen)

        return chosen, True
예제 #4
0
 def compiled_function_name(self,
                            original_fqn,
                            live_entity=None,
                            owner_type=None):
   """Returns `('renamed_<fqn joined by _>', True)`, or `(None, False)`.

   No name is produced when `live_entity` is a lambda or when `owner_type`
   is provided.
   """
   if inspect_utils.islambda(live_entity) or owner_type is not None:
     return None, False
   return 'renamed_%s' % '_'.join(original_fqn), True
 def compiled_function_name(self,
                            original_fqn,
                            live_entity=None,
                            owner_type=None):
     """Builds a `renamed_`-prefixed name from the parts of `original_fqn`.

     Lambdas and calls that carry an `owner_type` are not renamed; for those
     the result is `(None, False)`.
     """
     keeps_name = (
         inspect_utils.islambda(live_entity) or owner_type is not None)
     if keeps_name:
         return None, False

     joined = '_'.join(original_fqn)
     return 'renamed_%s' % joined, True
예제 #6
0
  def _should_compile(self, node, fqn):
    """Determines whether an entity should be compiled in the context.

    Args:
      node: the call node; only its `func` attribute is read here.
      fqn: sequence of name components for the call target. The first
        component is treated as the module name.

    Returns:
      False when the name or the resolved target falls under one of the
      program's uncompiled modules, when the target carries the
      `__ag_compiled` attribute, or when one of its decorators is listed in
      the program's `strip_decorators` option. True otherwise, including when
      the target cannot be resolved or parsing it raises TypeError.
    """
    # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
    module_name = fqn[0]
    if any(module_name.startswith(mod[0] + '.')
           for mod in self.ctx.program.uncompiled_modules):
      return False

    # Any proper prefix of the fqn may itself be listed as uncompiled.
    if any(fqn[:i] in self.ctx.program.uncompiled_modules
           for i in range(1, len(fqn))):
      return False

    target_entity = self._try_resolve_target(node.func)
    if target_entity is None:
      return True

    # Currently, lambdas are always converted.
    # TODO(mdan): Allow markers of the kind f = ag.do_not_convert(lambda: ...)
    if inspect_utils.islambda(target_entity):
      return True

    # This may be reached when "calling" a callable attribute of an object.
    # For example:
    #
    #   self.fc = tf.keras.layers.Dense()
    #   self.fc()
    #
    if any(target_entity.__module__.startswith(mod[0] + '.')
           for mod in self.ctx.program.uncompiled_modules):
      return False

    # Inspect the target function decorators. If any include a @convert
    # or @do_not_convert annotation, then they must be called as they are.
    # TODO(mdan): This may be quite heavy. Perhaps always dynamically convert?
    # To parse and re-analyze each function for every call site could be quite
    # wasteful. Maybe we could cache the parsed AST?
    try:
      parsed, _ = parser.parse_entity(target_entity)
      target_node = parsed.body[0]
    except TypeError:
      # Functions whose source we cannot access are compilable (e.g. wrapped
      # to py_func).
      return True

    # This attribute is set when the decorator was applied before the
    # function was parsed. See api.py.
    if hasattr(target_entity, '__ag_compiled'):
      return False

    for dec in target_node.decorator_list:
      decorator_fn = self._resolve_decorator_name(dec)
      if decorator_fn is None:
        continue
      if decorator_fn in self.ctx.program.options.strip_decorators:
        return False

    return True
예제 #7
0
    def test_islambda(self):
        # A lambda expression must be recognized; a `def` function must not.
        def named_fn():
            pass

        is_lambda = inspect_utils.islambda
        self.assertTrue(is_lambda(lambda x: x))
        self.assertFalse(is_lambda(named_fn))
예제 #8
0
  def test_islambda(self):
    # islambda is expected to be truthy for a lambda expression and falsy for
    # a function introduced with `def`.
    def regular_function():
      pass

    check = inspect_utils.islambda
    self.assertTrue(check(lambda x: x))
    self.assertFalse(check(regular_function))
예제 #9
0
def _errors_are_normally_possible(entity, error):
    """Reports whether `entity` is a lambda and `error` is a ValueError."""
    return bool(
        inspect_utils.islambda(entity) and isinstance(error, ValueError))
예제 #10
0
 def test_islambda_renamed_lambda(self):
     # A lambda whose __name__ was overwritten must still be reported as a
     # lambda.
     renamed = lambda x: 1
     setattr(renamed, '__name__', 'f')
     self.assertTrue(inspect_utils.islambda(renamed))