Example #1
0
def function_to_graph(f, conversion_map, arg_values, arg_types,
                      owner_type=None):
  """Converts a callable into an AST node via `node_to_graph`.

  Args:
    f: The function to convert.
    conversion_map: Supplies the api module, the namer, the recursion flag and
      the no-compile decorators; receives the updated name map and the
      dependencies reported by `node_to_graph`.
    arg_values: Forwarded to the entity context as `arg_values`.
    arg_types: Forwarded to the entity context as `arg_types`.
    owner_type: Forwarded to the entity context and to the namer.

  Returns:
    A tuple `(node, new_name)`: the converted function node and the name that
    was assigned to it.

  Raises:
    NotImplementedError: If the namer did not rename the function, yet the
      converted node's name differs from `f.__name__`.
  """
  module_node, source_code = parser.parse_entity(f)
  fn_node = module_node.body[0]

  fn_namespace = inspect_utils.getnamespace(f)
  _add_self_references(fn_namespace, conversion_map.api_module)
  fn_namer = conversion_map.new_namer(fn_namespace)

  entity_ctx = context.EntityContext(
      namer=fn_namer,
      source_code=source_code,
      source_file='<fragment>',
      namespace=fn_namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type,
      recursive=conversion_map.recursive)
  fn_node, extra_deps = node_to_graph(
      fn_node, entity_ctx, conversion_map.nocompile_decorators)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
  original_name = f.__name__
  compiled_name, was_renamed = fn_namer.compiled_function_name(
      original_name, f, owner_type)
  if not was_renamed:
    # Without a rename, the converted node is expected to keep the original
    # function name.
    if fn_node.name != original_name:
      raise NotImplementedError('Strange corner case. Send us offending code!')
    compiled_name = original_name

  fn_node.name = compiled_name
  conversion_map.update_name_map(fn_namer)
  # TODO(mdan): Use this at compilation.
  conversion_map.additional_imports.update(extra_deps)

  return fn_node, compiled_name
Example #2
0
 def parse_and_analyze(self,
                       test_fn,
                       namespace,
                       namer=None,
                       arg_types=None,
                       include_type_analysis=True,
                       owner_type=None,
                       recursive=True):
   """Parses `test_fn` and runs the static analysis passes over its AST.

   The entity context built for the analysis is stored on `self.ctx`.

   Args:
     test_fn: The function to parse.
     namespace: Passed to the entity context as `namespace`.
     namer: Passed to the entity context; a `FakeNamer` is used when falsy.
     arg_types: Passed to the entity context as `arg_types`.
     include_type_analysis: When true, `type_info` is resolved as well,
       followed by a second `live_values` pass.
     owner_type: Passed to the entity context as `owner_type`.
     recursive: Passed to the entity context as `recursive`.

   Returns:
     The node produced by the last analysis pass.
   """
   tree, source_code = parser.parse_entity(test_fn)
   entity_namer = namer or FakeNamer()
   entity_ctx = context.EntityContext(
       namer=entity_namer,
       source_code=source_code,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       owner_type=owner_type,
       recursive=recursive,
       type_annotation_func=utils.set_element_type)

   tree = activity.resolve(qual_names.resolve(tree), entity_ctx)
   tree = live_values.resolve(tree, entity_ctx, {})
   if include_type_analysis:
     # Type analysis is followed by a second live-values pass.
     tree = live_values.resolve(
         type_info.resolve(tree, entity_ctx), entity_ctx, {})

   self.ctx = entity_ctx
   return tree
Example #3
0
  def test_parse_entity(self):
    """Checks that the first body element of the parsed result is named 'f'."""

    def f(x):
      return x + 1

    module_node, _ = parser.parse_entity(f)
    first_definition = module_node.body[0]
    self.assertEqual('f', first_definition.name)
 def parse_and_analyze(self,
                       test_fn,
                       namespace,
                       namer=None,
                       arg_types=None,
                       include_type_analysis=True,
                       owner_type=None,
                       recursive=True):
   """Parses `test_fn` and runs the static analysis passes over its AST.

   The entity context built for the analysis is stored on `self.ctx`.

   Args:
     test_fn: The function to parse.
     namespace: Passed to the entity context as `namespace`.
     namer: Passed to the entity context; a `FakeNamer` is used when falsy.
     arg_types: Passed to the entity context as `arg_types`.
     include_type_analysis: When true, `type_info` is resolved as well,
       followed by a second `live_values` pass.
     owner_type: Passed to the entity context as `owner_type`.
     recursive: Passed to the entity context as `recursive`.

   Returns:
     The node produced by the last analysis pass.
   """
   tree, source_code = parser.parse_entity(test_fn)
   entity_namer = namer or FakeNamer()
   entity_ctx = context.EntityContext(
       namer=entity_namer,
       source_code=source_code,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       owner_type=owner_type,
       recursive=recursive)

   tree = activity.resolve(qual_names.resolve(tree), entity_ctx)
   tree = live_values.resolve(tree, entity_ctx, {})
   if include_type_analysis:
     # Type analysis is followed by a second live-values pass.
     tree = live_values.resolve(
         type_info.resolve(tree, entity_ctx), entity_ctx, {})

   self.ctx = entity_ctx
   return tree
Example #5
0
 def _parse_and_analyze(self, test_fn):
     """Parses `test_fn` and returns its AST after the `access` pass.

     The entity context uses an empty namespace, no namer, and no argument
     values or types.
     """
     tree, source_code = parser.parse_entity(test_fn)
     entity_ctx = context.EntityContext(
         namer=None,
         source_code=source_code,
         source_file=None,
         namespace={},
         arg_values=None,
         arg_types=None,
         recursive=True)
     return access.resolve(tree, entity_ctx)
Example #6
0
 def _parse_and_analyze(self, test_fn):
   """Parses `test_fn`, then applies the `qual_names` and `activity` passes.

   The entity context uses an empty namespace, no namer, and no argument
   values or types.
   """
   tree, source_code = parser.parse_entity(test_fn)
   entity_ctx = context.EntityContext(
       namer=None,
       source_code=source_code,
       source_file=None,
       namespace={},
       arg_values=None,
       arg_types=None,
       recursive=True)
   qualified_tree = qual_names.resolve(tree)
   return activity.resolve(qualified_tree, entity_ctx)
Example #7
0
  def test_parser_compile_idempotent(self):
    """Checks that parsing then compiling a function preserves its source."""

    # The text of this function is the fixture being compared; keep it as is.
    def test_fn(x):
      a = True
      b = ''
      if a:
        b = x + 1
      return b

    expected_source = textwrap.dedent(tf_inspect.getsource(test_fn))

    fn_node = parser.parse_entity(test_fn)[0].body[0]
    compiled_module = compiler.ast_to_object(fn_node)[0]
    actual_source = tf_inspect.getsource(compiled_module.test_fn)

    self.assertEqual(expected_source, actual_source)
 def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
     """Parses `test_fn` and runs the analysis passes over its AST.

     Passes applied, in order: `qual_names`, `activity`, `live_values`,
     `type_info`, and `live_values` once more.

     Args:
         test_fn: The function to parse.
         namespace: Passed to the entity context as `namespace`.
         arg_types: Passed to the entity context as `arg_types`.

     Returns:
         The node produced by the final pass.
     """
     tree, source_code = parser.parse_entity(test_fn)
     entity_ctx = context.EntityContext(
         namer=None,
         source_code=source_code,
         source_file=None,
         namespace=namespace,
         arg_values=None,
         arg_types=arg_types,
         recursive=True)
     tree = activity.resolve(qual_names.resolve(tree), entity_ctx)
     tree = live_values.resolve(tree, entity_ctx, {})
     typed_tree = type_info.resolve(tree, entity_ctx)
     return live_values.resolve(typed_tree, entity_ctx, {})
Example #9
0
 def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
   """Parses `test_fn` and runs the analysis passes over its AST.

   Passes applied, in order: `access`, `live_values`, `type_info`, and
   `live_values` once more.

   Args:
     test_fn: The function to parse.
     namespace: Passed to the entity context as `namespace`.
     arg_types: Passed to the entity context as `arg_types`.

   Returns:
     The node produced by the final pass.
   """
   tree, source_code = parser.parse_entity(test_fn)
   entity_ctx = context.EntityContext(
       namer=None,
       source_code=source_code,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       recursive=True)
   tree = access.resolve(tree, entity_ctx)
   tree = live_values.resolve(tree, entity_ctx, {})
   typed_tree = type_info.resolve(tree, entity_ctx)
   return live_values.resolve(typed_tree, entity_ctx, {})
Example #10
0
  def _should_compile(self, node, fqn):
    """Determines whether an entity should be compiled in the context."""
    # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
    module_name = fqn[0]
    for mod in self.uncompiled_modules:
      if module_name.startswith(mod[0] + '.'):
        return False

    for i in range(1, len(fqn)):
      if fqn[:i] in self.uncompiled_modules:
        return False

    # Check for local decorations
    if anno.hasanno(node, 'graph_ready'):
      return False

    # The decorators themselves are not to be converted.
    # If present, the decorators should appear as static functions.
    target_entity = self._try_resolve_target(node.func)
    if target_entity is not None:
      # This attribute is set by the decorator itself.
      # TODO(mdan): This may not play nicely with other wrapping decorators.
      if hasattr(target_entity, '__pyct_is_compile_decorator'):
        return False

      if target_entity in self.nocompile_decorators:
        return False

      # Inspect the target function decorators. If any include a @convert
      # or @graph_ready annotation, then they must be called as they are.
      # TODO(mdan): This may be quite heavy.
      # To parse and re-analize each function for every call site could be quite
      # wasteful. Maybe we could cache the parsed AST?
      try:
        target_node, _ = parser.parse_entity(target_entity)
        target_node = target_node.body[0]
      except TypeError:
        # Functions whose source we cannot access are compilable (e.g. wrapped
        # to py_func).
        return True

      for dec in target_node.decorator_list:
        decorator_fn = self._resolve_name(dec)
        if (decorator_fn is not None and
            decorator_fn in self.nocompile_decorators):
          return False

    return True
Example #11
0
def function_to_graph(f,
                      conversion_map,
                      arg_values,
                      arg_types,
                      owner_type=None):
    """Specialization of `entity_to_graph` for callable functions.

    Note that the namespace used for conversion is the globals mapping
    returned by `six.get_function_globals(f)`; it is updated in place with the
    callables found in the closure of `f` and with the `py2tf_utils` module.

    Args:
        f: The function to convert.
        conversion_map: Supplies the namer, the recursion flag and the
            no-compile decorators; receives the updated name map.
        arg_values: Forwarded to the entity context as `arg_values`.
        arg_types: Forwarded to the entity context as `arg_types`.
        owner_type: Forwarded to the namer when computing the compiled name.

    Returns:
        A tuple `(node, new_name)`: the converted function node and the name
        that was assigned to it.

    Raises:
        ValueError: If the namespace already binds `py2tf_utils` to something
            other than the utils module.
        NotImplementedError: If the namer did not rename the function, yet
            the converted node's name differs from `f.__name__`.
    """
    node, source = parser.parse_entity(f)
    node = node.body[0]
    namespace = six.get_function_globals(f)

    # This is needed for non-global functions.
    closure = six.get_function_closure(f)
    if closure:
        for e in closure:
            if callable(e.cell_contents):
                fn = e.cell_contents
                namespace[fn.__name__] = fn

    # Manually add the utils namespace which may be used from generated code.
    # The membership test must use the very key that is written and compared;
    # otherwise an existing binding is overwritten silently, or the lookup
    # below fails with a KeyError.
    if 'py2tf_utils' not in namespace:
        namespace['py2tf_utils'] = utils
    elif namespace['py2tf_utils'] != utils:
        raise ValueError(
            'The module name py2tf_utils is reserved and may not be used.')

    namer = conversion_map.new_namer(namespace)
    ctx = context.EntityContext(namer=namer,
                                source_code=source,
                                source_file='<fragment>',
                                namespace=namespace,
                                arg_values=arg_values,
                                arg_types=arg_types,
                                recursive=conversion_map.recursive)
    node = node_to_graph(node, ctx, conversion_map.nocompile_decorators)

    # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
    new_name, did_rename = namer.compiled_function_name(
        f.__name__, f, owner_type)
    if not did_rename:
        new_name = f.__name__
        # Without a rename, the converted node is expected to keep the
        # original function name.
        if node.name != f.__name__:
            raise NotImplementedError(
                'Strange corner case. Send us offending code!')

    node.name = new_name
    conversion_map.update_name_map(namer)
    return node, new_name
Example #12
0
def function_to_graph(f, conversion_map, arg_values, arg_types,
                      owner_type=None):
  """Specialization of `entity_to_graph` for callable functions.

  Note that the namespace used for conversion is the globals mapping returned
  by `six.get_function_globals(f)`; it is updated in place with the callables
  found in the closure of `f` and with the `py2tf_utils` module.

  Args:
    f: The function to convert.
    conversion_map: Supplies the namer, the recursion flag and the no-compile
      decorators; receives the updated name map.
    arg_values: Forwarded to the entity context as `arg_values`.
    arg_types: Forwarded to the entity context as `arg_types`.
    owner_type: Forwarded to the namer when computing the compiled name.

  Returns:
    A tuple `(node, new_name)`: the converted function node and the name that
    was assigned to it.

  Raises:
    ValueError: If the namespace already binds `py2tf_utils` to something
      other than the utils module.
    NotImplementedError: If the namer did not rename the function, yet the
      converted node's name differs from `f.__name__`.
  """
  node, source = parser.parse_entity(f)
  node = node.body[0]
  namespace = six.get_function_globals(f)

  # This is needed for non-global functions.
  closure = six.get_function_closure(f)
  if closure:
    for e in closure:
      if callable(e.cell_contents):
        fn = e.cell_contents
        namespace[fn.__name__] = fn

  # Manually add the utils namespace which may be used from generated code.
  # The membership test must use the very key that is written and compared;
  # otherwise an existing binding is overwritten silently, or the lookup below
  # fails with a KeyError.
  if 'py2tf_utils' not in namespace:
    namespace['py2tf_utils'] = utils
  elif namespace['py2tf_utils'] != utils:
    raise ValueError(
        'The module name py2tf_utils is reserved and may not be used.')

  namer = conversion_map.new_namer(namespace)
  ctx = context.EntityContext(
      namer=namer,
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      recursive=conversion_map.recursive)
  node = node_to_graph(node, ctx, conversion_map.nocompile_decorators)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
  new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
  if not did_rename:
    new_name = f.__name__
    # Without a rename, the converted node is expected to keep the original
    # function name.
    if node.name != f.__name__:
      raise NotImplementedError('Strange corner case. Send us offending code!')

  node.name = new_name
  conversion_map.update_name_map(namer)
  return node, new_name
Example #13
0
 def _parse_and_analyze(self,
                        test_fn,
                        namespace,
                        literals=None,
                        arg_types=None):
     """Parses `test_fn` and runs the analysis passes over its AST.

     Passes applied, in order: `access`, `live_values`, `type_info`, and
     `live_values` once more.

     Args:
         test_fn: The function to parse.
         namespace: Passed to the entity context as `namespace`.
         literals: Passed to both `live_values` passes; an empty dict is used
             when falsy.
         arg_types: Passed to the entity context as `arg_types`; an empty
             dict is used when falsy.

     Returns:
         The node produced by the final pass.
     """
     if not literals:
         literals = {}
     if not arg_types:
         arg_types = {}

     tree, source_code = parser.parse_entity(test_fn)
     entity_ctx = context.EntityContext(
         namer=None,
         source_code=source_code,
         source_file=None,
         namespace=namespace,
         arg_values=None,
         arg_types=arg_types,
         recursive=True)
     tree = access.resolve(tree, entity_ctx)
     tree = live_values.resolve(tree, entity_ctx, literals)
     typed_tree = type_info.resolve(tree, entity_ctx)
     return live_values.resolve(typed_tree, entity_ctx, literals)
Example #14
0
 def _parse_and_analyze(self,
                        test_fn,
                        namespace,
                        arg_types=None):
   """Parses `test_fn` and runs the analysis passes over its AST.

   Passes applied, in order: `qual_names`, `activity`, `live_values`,
   `type_info`, and `live_values` once more. The entity context is given
   `utils.set_element_type` as its type annotation function.

   Args:
     test_fn: The function to parse.
     namespace: Passed to the entity context as `namespace`.
     arg_types: Passed to the entity context as `arg_types`.

   Returns:
     The node produced by the final pass.
   """
   tree, source_code = parser.parse_entity(test_fn)
   entity_ctx = context.EntityContext(
       namer=None,
       source_code=source_code,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       owner_type=None,
       recursive=True,
       type_annotation_func=utils.set_element_type)
   tree = activity.resolve(qual_names.resolve(tree), entity_ctx)
   tree = live_values.resolve(tree, entity_ctx, {})
   typed_tree = type_info.resolve(tree, entity_ctx)
   return live_values.resolve(typed_tree, entity_ctx, {})
Example #15
0
 def parse_and_analyze(self,
                       test_fn,
                       namespace,
                       namer=None,
                       arg_types=None,
                       include_type_analysis=True,
                       recursive=True):
     """Parses `test_fn` and runs the static analysis passes over its AST.

     The entity context built for the analysis is stored on `self.ctx`.

     Args:
         test_fn: The function to parse.
         namespace: Passed to the entity context as `namespace`.
         namer: Passed to the entity context as `namer`.
         arg_types: Passed to the entity context as `arg_types`.
         include_type_analysis: When true, `type_info` is resolved as well,
             followed by a second `live_values` pass.
         recursive: Passed to the entity context as `recursive`.

     Returns:
         The node produced by the last analysis pass.
     """
     tree, source_code = parser.parse_entity(test_fn)
     entity_ctx = context.EntityContext(
         namer=namer,
         source_code=source_code,
         source_file=None,
         namespace=namespace,
         arg_values=None,
         arg_types=arg_types,
         recursive=recursive)

     tree = access.resolve(tree, entity_ctx)
     tree = live_values.resolve(tree, entity_ctx, {})
     if include_type_analysis:
         # Type analysis is followed by a second live-values pass.
         tree = live_values.resolve(
             type_info.resolve(tree, entity_ctx), entity_ctx, {})

     self.ctx = entity_ctx
     return tree
 def _parse_and_analyze(self,
                        test_fn,
                        namespace,
                        literals=None,
                        arg_types=None):
   """Parses `test_fn` and runs the analysis passes over its AST.

   Passes applied, in order: `qual_names`, `activity`, `live_values`,
   `type_info`, and `live_values` once more.

   Args:
     test_fn: The function to parse.
     namespace: Passed to the entity context as `namespace`.
     literals: Passed to both `live_values` passes; an empty dict is used when
       falsy.
     arg_types: Passed to the entity context as `arg_types`; an empty dict is
       used when falsy.

   Returns:
     The node produced by the final pass.
   """
   if not literals:
     literals = {}
   if not arg_types:
     arg_types = {}

   tree, source_code = parser.parse_entity(test_fn)
   entity_ctx = context.EntityContext(
       namer=None,
       source_code=source_code,
       source_file=None,
       namespace=namespace,
       arg_values=None,
       arg_types=arg_types,
       recursive=True)
   tree = activity.resolve(qual_names.resolve(tree), entity_ctx)
   tree = live_values.resolve(tree, entity_ctx, literals)
   typed_tree = type_info.resolve(tree, entity_ctx)
   return live_values.resolve(typed_tree, entity_ctx, literals)
Example #17
0
    def test_parse_entity(self):
        """Checks that the first body element of the parsed result is named 'f'."""

        def f(x):
            return x + 1

        module_node, _ = parser.parse_entity(f)
        first_definition = module_node.body[0]
        self.assertEqual('f', first_definition.name)