Example #1
0
def to_graph(f, arg_value_hints=None):
    """Compile a Python function into equivalent TensorFlow code.

    The function is converted through `conversion.object_to_graph`; the
    configured import statements and every converted dependency are then
    assembled into a single module that is compiled with
    `compiler.ast_to_object`.

    Args:
        f: A Python function with arbitrary arguments and return values.
        arg_value_hints: A dict mapping parameter names to objects that can
            hint at the type of those parameters.

    Returns:
        A function with a signature identical to `f`, but which, when
        executed, creates a TF graph that has the same functionality as the
        original function.
    """
    conversion_map = conversion.ConversionMap()
    _, name = conversion.object_to_graph(f, conversion_map, arg_value_hints)

    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        # parse_str yields a module-level node; splice its statements in
        # instead of nesting that node inside another module's body.
        module.body.extend(parser.parse_str(import_line).body)
    module.body.extend(conversion_map.dependency_cache.values())
    compiled_node = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    compiled_node.__dict__.update(six.get_function_globals(f))

    return getattr(compiled_node, name)
Example #2
0
def replace(template, **replacements):
  """Substitute placeholders in a Python code template.

  The template is parsed, every replacement value is passed through
  `_strings_to_names`, and `ReplaceTransformer` is run over the parsed tree.

  AST Name and Tuple nodes always receive the context that is inferred from
  the template. When replacing with more complex nodes (which may contain Name
  children), the caller is responsible for setting the appropriate context.

  Args:
    template: A string of Python code. Any symbol name that appears in the
        template code can be used as a placeholder.
    **replacements: A mapping from placeholder names to (lists of) AST nodes
        that the placeholders will be replaced by. String values are also
        supported as a shorthand for AST Name nodes with the respective ID.

  Returns:
    The `body` of the parsed template after the transformer was applied.

  Raises:
    ValueError: if `template` is not a string.
  """
  if not isinstance(template, str):
    raise ValueError('Expected string template, got %s' % type(template))

  parsed_template = parser.parse_str(template)
  node_replacements = {
      placeholder: _strings_to_names(value)
      for placeholder, value in replacements.items()
  }
  transformed = ReplaceTransformer(node_replacements).visit(parsed_template)
  return transformed.body
Example #3
0
def to_graph(o, arg_value_hints=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    o: A Python function or class.
    arg_value_hints: A dict mapping parameter names to objects that can hint
        at the type of those parameters.

  Returns:
    The compiled counterpart of `o`. For a function, this is a function with
    a signature identical to `o`, but which, when executed, creates a TF graph
    that has the same functionality as the original entity.
  """
  conversion_map = conversion.ConversionMap()
  _, name = conversion.object_to_graph(o, conversion_map, arg_value_hints)

  module = gast.Module([])
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    # parse_str yields a module-level node; splice its statements in instead
    # of nesting that node inside another module's body.
    module.body.extend(parser.parse_str(import_line).body)
  module.body.extend(conversion_map.dependency_cache.values())
  compiled_node = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(o):
    compiled_node.__dict__.update(six.get_function_globals(o))

  return getattr(compiled_node, name)
Example #4
0
  def test_subscript_resolve(self):
    # Every line of `samples` is one expression statement; after `resolve`,
    # the value of the i-th statement must carry the i-th expected name.
    samples = """
      x[i]
      x[i.b]
      a.b[c]
      a.b[x.y]
      a[z[c]]
      a[b[c[d]]]
      a[b].c
      a.b.c[d].e.f
      a.b[c[d]].e.f
      a.b[c[d.e.f].g].h
    """
    expected_names = (
        'x[i]',
        'x[i.b]',
        'a.b[c]',
        'a.b[x.y]',
        'a[z[c]]',
        'a[b[c[d]]]',
        'a[b].c',
        'a.b.c[d].e.f',
        'a.b[c[d]].e.f',
        'a.b[c[d.e.f].g].h',
    )
    resolved = resolve(parser.parse_str(textwrap.dedent(samples)))
    exprs = tuple(stmt.value for stmt in resolved.body)

    for index, qualified_name in enumerate(expected_names):
      self.assertQNStringIs(exprs[index], qualified_name)
 def visit_UnaryOp(self, node):
   """Rewrite a logical `not` as a call to its mapped function expression.

   Children are visited first. Unary operators other than `gast.Not` are
   returned as visited, without further changes.
   """
   visited = self.generic_visit(node)
   if not isinstance(visited.op, gast.Not):
     return visited
   template = self.op_mapping[type(visited.op)]
   # The value of the first parsed statement is used as the callee.
   func_expr = parser.parse_str(template).body[0].value
   return gast.Call(func=func_expr, args=[visited.operand], keywords=[])
Example #6
0
 def test_parse_str(self):
   # Parsing a dedented one-function source must yield a module whose first
   # statement is the definition of `f`.
   source = textwrap.dedent("""
           def f(x):
             return x + 1
   """)
   parsed_module = parser.parse_str(source)
   first_statement = parsed_module.body[0]
   self.assertEqual('f', first_statement.name)
Example #7
0
 def test_parse_str(self):
     # Parsing a dedented one-function source must yield a module whose first
     # statement is the definition of `f`.
     source = textwrap.dedent("""
         def f(x):
           return x + 1
 """)
     parsed_module = parser.parse_str(source)
     first_statement = parsed_module.body[0]
     self.assertEqual('f', first_statement.name)
Example #8
0
def replace(template, **replacements):
    """Substitute placeholders in a Python code template.

    The template is dedented and parsed, every replacement value is passed
    through `_convert_to_ast`, and `ReplaceTransformer` is run over the parsed
    tree. The resulting node(s) are passed through `qual_names.resolve`.

    AST Name and Tuple nodes always receive the context that is inferred from
    the template. When replacing with more complex nodes (which may contain
    Name children), the caller is responsible for setting the appropriate
    context.

    Args:
        template: A string of Python code. Any symbol name that appears in
            the template code can be used as a placeholder.
        **replacements: A mapping from placeholder names to (lists of) AST
            nodes that the placeholders will be replaced by. String values
            are also supported as a shorthand for AST Name nodes with the
            respective ID.

    Returns:
        The `body` of the transformed template: a list of resolved nodes when
        that body is a list, otherwise the single resolved node.

    Raises:
        ValueError: if `template` is not a string.
    """
    if not isinstance(template, str):
        raise ValueError('Expected string template, got %s' % type(template))

    parsed_template = parser.parse_str(textwrap.dedent(template))
    node_replacements = {
        placeholder: _convert_to_ast(value)
        for placeholder, value in replacements.items()
    }
    transformed = ReplaceTransformer(node_replacements).visit(parsed_template)
    body = transformed.body
    if not isinstance(body, list):
        return qual_names.resolve(body)
    return [qual_names.resolve(item) for item in body]
Example #9
0
def to_graph(f, arg_value_hints=None):
  """Compile a Python function into equivalent TensorFlow code.

  The function is converted through `conversion.object_to_graph`; the
  configured import statements and every converted dependency are then
  assembled into a single module that is compiled with
  `compiler.ast_to_object`.

  Args:
    f: A Python function with arbitrary arguments and return values.
    arg_value_hints: A dict mapping parameter names to objects that can hint
        at the type of those parameters.

  Returns:
    A function with a signature identical to `f`, but which, when executed,
    creates a TF graph that has the same functionality as the original
    function.
  """
  conversion_map = conversion.ConversionMap()
  _, name = conversion.object_to_graph(f, conversion_map, arg_value_hints)

  module = gast.Module([])
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    # parse_str yields a module-level node; splice its statements in instead
    # of nesting that node inside another module's body.
    module.body.extend(parser.parse_str(import_line).body)
  module.body.extend(conversion_map.dependency_cache.values())
  compiled_node = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  compiled_node.__dict__.update(six.get_function_globals(f))

  return getattr(compiled_node, name)
Example #10
0
    def test_subscript_resolve(self):
        # Every line of `samples` is one expression statement; after
        # `resolve`, the value of the i-th statement must carry the i-th
        # expected name.
        samples = """
      x[i]
      x[i.b]
      a.b[c]
      a.b[x.y]
      a[z[c]]
      a[b[c[d]]]
      a[b].c
      a.b.c[d].e.f
      a.b[c[d]].e.f
      a.b[c[d.e.f].g].h
    """
        expected_names = (
            'x[i]',
            'x[i.b]',
            'a.b[c]',
            'a.b[x.y]',
            'a[z[c]]',
            'a[b[c[d]]]',
            'a[b].c',
            'a.b.c[d].e.f',
            'a.b[c[d]].e.f',
            'a.b[c[d.e.f].g].h',
        )
        resolved = resolve(parser.parse_str(textwrap.dedent(samples)))
        exprs = tuple(stmt.value for stmt in resolved.body)

        for index, qualified_name in enumerate(expected_names):
            self.assertQNStringIs(exprs[index], qualified_name)
 def visit_BoolOp(self, node):
   """Rewrite a boolean operation as nested calls to its mapped function.

   The function expression is looked up in `self.op_mapping` by operator
   type. Operands are folded left to right, so `a op b op c` becomes
   `fn(fn(a, b), c)`.

   Args:
     node: The boolean operation node to convert.

   Returns:
     The first operand if there is only one, otherwise the outermost call
     node of the fold.
   """
   # TODO(mdan): A normalizer may be useful here. Use ANF?
   # Visit the operands first; otherwise expressions nested inside them would
   # be left untransformed.
   node = self.generic_visit(node)
   tf_function = parser.parse_str(self.op_mapping[type(node.op)]).body[0].value
   result = node.values[0]
   for operand in node.values[1:]:
     result = gast.Call(func=tf_function, args=[result, operand], keywords=[])
   return result
 def visit_BoolOp(self, node):
   """Rewrite a boolean operation as nested calls to its mapped function.

   Children are visited first. The function expression is looked up in
   `self.op_mapping` by operator type, and the operands are folded left to
   right, so `a op b op c` becomes `fn(fn(a, b), c)`.
   """
   # TODO(mdan): A normalizer may be useful here. Use ANF?
   visited = self.generic_visit(node)
   template = self.op_mapping[type(visited.op)]
   func_expr = parser.parse_str(template).body[0].value
   operands = visited.values
   folded = operands[0]
   for operand in operands[1:]:
     folded = gast.Call(func=func_expr, args=[folded, operand], keywords=[])
   return folded
Example #13
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Currently supported entities:
      * functions
      * classes

    Classes are handled by converting all their methods into a new class.

    Args:
        e: A Python entity.
        recursive: Whether to recursively convert any functions that the
            entity may call.
        verbose: Whether to output the compiled code in the logs.
        arg_values: A dict containing value hints for symbols like function
            parameters.
        arg_types: A dict containing type hints for symbols like function
            parameters.
        partial_types: A set of types (e.g. classes) that will not be
            converted entirely. Calls to member functions for these types
            will be renamed independently.

    Returns:
        The compiled counterpart of `e`. For a function, this is a function
        with a signature identical to `e`, but which, when executed, creates
        a TF graph that has the same functionality as the original entity.
    """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, graph_ready, convert_inline),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    _, name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                         arg_types)

    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        # parse_str yields a module-level node; splice its statements in
        # instead of nesting that node inside another module's body.
        module.body.extend(parser.parse_str(import_line).body)
    module.body.extend(conversion_map.dependency_cache.values())
    compiled_node, compiled_src = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        compiled_node.__dict__.update(six.get_function_globals(e))
    compiled_fn = getattr(compiled_node, name)

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled_fn
 def visit_Compare(self, node):
   """Rewrite a single comparison as a call to its mapped function expression.

   Children are visited first. Comparisons whose operator type has no entry
   in `self.op_mapping` are returned as visited.

   Raises:
     NotImplementedError: if the comparison has more than one operator.
   """
   visited = self.generic_visit(node)
   if len(visited.ops) > 1:
     raise NotImplementedError()
   operator_type = type(visited.ops[0])
   if operator_type not in self.op_mapping:
     return visited
   func_expr = parser.parse_str(self.op_mapping[operator_type]).body[0].value
   operands = [visited.left, visited.comparators[0]]
   return gast.Call(func=func_expr, args=operands, keywords=[])
Example #15
0
 def test_keywords_to_dict(self):
     # Convert the keywords of a parsed call into a dict node, then verify
     # the node by compiling it.
     call_node = parser.parse_expression("f(a=b, c=1, d='e')")
     dict_node = ast_util.keywords_to_dict(call_node.keywords)
     # Attach the dict node to a variable `d` in a module that also defines
     # `b`, and compile everything, to make sure the node is usable.
     module_node = parser.parse_str('b = 3')
     assign_d = ast.Assign([ast.Name(id='d', ctx=ast.Store())], dict_node)
     module_node.body += (assign_d, )
     compiled_module, _ = compiler.ast_to_object(module_node)
     expected = {'a': 3, 'c': 1, 'd': 'e'}
     self.assertDictEqual(compiled_module.d, expected)
     print(dict_node)
Example #16
0
 def test_keywords_to_dict(self):
   # Convert the keywords of a parsed call into a dict node, then verify the
   # node by compiling it.
   call_node = parser.parse_expression("f(a=b, c=1, d='e')")
   dict_node = ast_util.keywords_to_dict(call_node.keywords)
   # Attach the dict node to a variable `d` in a module that also defines `b`,
   # and compile everything, to make sure the node is usable.
   module_node = parser.parse_str('b = 3')
   assign_d = ast.Assign([ast.Name(id='d', ctx=ast.Store())], dict_node)
   module_node.body += (assign_d,)
   compiled_module, _ = compiler.ast_to_object(module_node)
   expected = {'a': 3, 'c': 1, 'd': 'e'}
   self.assertDictEqual(compiled_module.d, expected)
   print(dict_node)
Example #17
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Convert a Python entity into equivalent TensorFlow code and compile it.

  Supported entities are currently:
    * functions
    * classes

  A class is handled by converting all of its methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the entity
        may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    The compiled counterpart of `e`. For a function, this is a function with
    a signature identical to `e`, but which, when executed, creates a TF graph
    that has the same functionality as the original entity.
  """
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, graph_ready, convert_inline),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  # Module contents: the configured import statements first, followed by
  # every converted dependency.
  statements = []
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    statements.extend(parser.parse_str(import_line).body)
  statements.extend(conversion_map.dependency_cache.values())
  compiled_node, compiled_src = compiler.ast_to_object(gast.Module(statements))

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    entry_namespace = inspect_utils.getnamespace(e)
    compiled_node.__dict__.update(entry_namespace)
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
Example #18
0
  def test_to_code_basic(self):
    # NOTE(review): test_fn is the input handed to api.to_code and is kept
    # exactly as written; presumably its source text is what gets converted.
    def test_fn(x, s):
      while math_ops.reduce_sum(x) > s:
        x /= 2
      return x

    generated_source = api.to_code(test_fn)

    # Smoke check only: the expected key words must appear, and the output
    # must be parseable Python code.
    self.assertRegexpMatches(generated_source, r'py2tf_utils\.run_while')
    reparsed = parser.parse_str(generated_source)
    self.assertIsNotNone(reparsed)
Example #19
0
  def test_to_code_basic(self):
    # NOTE(review): test_fn is the input handed to api.to_code and is kept
    # exactly as written; presumably its source text is what gets converted.
    def test_fn(x, s):
      while math_ops.reduce_sum(x) > s:
        x /= 2
      return x

    generated_source = api.to_code(test_fn)

    # Smoke check only: the expected key words must appear, and the output
    # must be parseable Python code.
    self.assertRegexpMatches(generated_source, r'py2tf_utils\.run_while')
    reparsed = parser.parse_str(generated_source)
    self.assertIsNotNone(reparsed)
Example #20
0
    def test_to_code_basic(self):
        # NOTE(review): test_fn is the input handed to api.to_code and is
        # kept exactly as written; presumably its source text is what gets
        # converted.
        def test_fn(x, s):
            while math_ops.reduce_sum(x) > s:
                x /= 2
            return x

        # Register the module that provides math_ops as uncompiled before
        # converting.
        uncompiled_module = (math_ops.__name__, )
        config.DEFAULT_UNCOMPILED_MODULES.add(uncompiled_module)
        generated_source = api.to_code(test_fn)

        # Smoke check only: the expected key words must appear, and the
        # output must be parseable Python code.
        self.assertRegexpMatches(generated_source, r'tf\.while_loop')
        reparsed = parser.parse_str(generated_source)
        self.assertIsNotNone(reparsed)
Example #21
0
 def test_function_calls(self):
     # Checks the names produced by `resolve` for attribute and subscript
     # expressions, including when they are the callee of a call or are
     # applied to a call result.
     samples = """
   a.b
   a.b()
   a().b
   z[i]
   z[i]()
   z()[i]
 """
     resolved = resolve(parser.parse_str(textwrap.dedent(samples)))
     exprs = [stmt.value for stmt in resolved.body]

     # Attribute access: plain, as callee, and on a call result.
     self.assertQNStringIs(exprs[0], 'a.b')
     self.assertQNStringIs(exprs[1].func, 'a.b')
     self.assertQNStringIs(exprs[2].value.func, 'a')
     # Subscript access: plain, as callee, and on a call result.
     self.assertQNStringIs(exprs[3], 'z[i]')
     self.assertQNStringIs(exprs[4].func, 'z[i]')
     self.assertQNStringIs(exprs[5].value.func, 'z')
Example #22
0
 def test_function_calls(self):
   # Checks the names produced by `resolve` for attribute and subscript
   # expressions, including when they are the callee of a call or are applied
   # to a call result.
   samples = """
     a.b
     a.b()
     a().b
     z[i]
     z[i]()
     z()[i]
   """
   resolved = resolve(parser.parse_str(textwrap.dedent(samples)))
   exprs = [stmt.value for stmt in resolved.body]

   # Attribute access: plain, as callee, and on a call result.
   self.assertQNStringIs(exprs[0], 'a.b')
   self.assertQNStringIs(exprs[1].func, 'a.b')
   self.assertQNStringIs(exprs[2].value.func, 'a')
   # Subscript access: plain, as callee, and on a call result.
   self.assertQNStringIs(exprs[3], 'z[i]')
   self.assertQNStringIs(exprs[4].func, 'z[i]')
   self.assertQNStringIs(exprs[5].value.func, 'z')
Example #23
0
    def test_resolve(self):
        # Checks the names produced by `resolve` for plain names, attribute
        # chains, tuple/list elements, and the callee and arguments of a
        # call.
        samples = """
      a
      a.b
      (c, d.e)
      [f, (g.h.i)]
      j(k, l)
    """
        resolved = resolve(parser.parse_str(textwrap.dedent(samples)))
        exprs = [stmt.value for stmt in resolved.body]

        self.assertQNStringIs(exprs[0], 'a')
        self.assertQNStringIs(exprs[1], 'a.b')

        tuple_expr = exprs[2]
        self.assertQNStringIs(tuple_expr.elts[0], 'c')
        self.assertQNStringIs(tuple_expr.elts[1], 'd.e')

        list_expr = exprs[3]
        self.assertQNStringIs(list_expr.elts[0], 'f')
        self.assertQNStringIs(list_expr.elts[1], 'g.h.i')

        call_expr = exprs[4]
        self.assertQNStringIs(call_expr.func, 'j')
        self.assertQNStringIs(call_expr.args[0], 'k')
        self.assertQNStringIs(call_expr.args[1], 'l')
Example #24
0
  def test_resolve(self):
    # Checks the names produced by `resolve` for plain names, attribute
    # chains, tuple/list elements, and the callee and arguments of a call.
    samples = """
      a
      a.b
      (c, d.e)
      [f, (g.h.i)]
      j(k, l)
    """
    resolved = resolve(parser.parse_str(textwrap.dedent(samples)))
    exprs = [stmt.value for stmt in resolved.body]

    self.assertQNStringIs(exprs[0], 'a')
    self.assertQNStringIs(exprs[1], 'a.b')

    tuple_expr = exprs[2]
    self.assertQNStringIs(tuple_expr.elts[0], 'c')
    self.assertQNStringIs(tuple_expr.elts[1], 'd.e')

    list_expr = exprs[3]
    self.assertQNStringIs(list_expr.elts[0], 'f')
    self.assertQNStringIs(list_expr.elts[1], 'g.h.i')

    call_expr = exprs[4]
    self.assertQNStringIs(call_expr.func, 'j')
    self.assertQNStringIs(call_expr.args[0], 'k')
    self.assertQNStringIs(call_expr.args[1], 'l')
 def instantiate_list_node(self):
     """Return the expression node obtained by parsing the literal `[]`."""
     parsed = parser.parse_str('[]')
     list_statement = parsed.body[0]
     return list_statement.value
 def instantiate_list_node(self):
   """Return the expression node obtained by parsing the literal `[]`."""
   parsed = parser.parse_str('[]')
   list_statement = parsed.body[0]
   return list_statement.value