コード例 #1
0
def to_code(e,
            recursive=True,
            arg_values=None,
            arg_types=None,
            partial_types=None,
            indentation='  '):
    """Return the equivalent of an entity in TensorFlow code.

    See `to_graph` for more details.

    Args:
      e: A Python entity.
      recursive: See to_graph.
      arg_values: See to_graph.
      arg_types: See to_graph.
      partial_types: See to_graph.
      indentation: String to use for each level of indentation.

    Returns:
      String: the compiled import statements, a blank line, then the source
      of every converted dependency.
    """
    cmap = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, graph_ready, convert_inline),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    conversion.entity_to_graph(e, cmap, arg_values, arg_types)

    header = '\n'.join(config.COMPILED_IMPORT_STATEMENTS)

    # Emit the cached dependencies in the reverse of the cache's iteration
    # order.
    dep_nodes = list(six.itervalues(cmap.dependency_cache))
    dep_nodes.reverse()
    dep_sources = [compiler.ast_to_source(node, indentation)
                   for node in dep_nodes]

    return '\n\n'.join((header, '\n'.join(dep_sources)))
コード例 #2
0
    def test_entity_to_graph_callable(self):
        """A plain function converts to a FunctionDef named `tf__<name>`."""
        def f(a):
            return a

        conversion_map = conversion.ConversionMap(True, (), (), None)
        # Named `node` rather than `ast` to avoid shadowing the stdlib module.
        node, new_name = conversion.entity_to_graph(f, conversion_map, None,
                                                    None)
        # assertIsInstance reports the actual type on failure; the node is
        # also passed as the message to aid debugging.
        self.assertIsInstance(node, gast.FunctionDef, node)
        self.assertEqual('tf__f', new_name)
コード例 #3
0
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Currently supported entities:
      * functions
      * classes

    Classes are handled by converting all their methods into a new class.

    Args:
      e: A Python entity.
      recursive: Whether to recursively convert any functions that the
          converted function may call.
      verbose: Whether to output the compiled code in the logs.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.
      partial_types: A set of types (e.g. classes) that will not be converted
          entirely. Calls to member functions for these types will be renamed
          independently.

    Returns:
      A function with a signature identical to `e`, but which when executed
      creates a TF graph that has the same functionality as the original
      entity.
    """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, graph_ready, convert_inline),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    _, name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                         arg_types)

    # Assemble one module: the compiled import statements followed by every
    # converted dependency, then compile it into a live module object.
    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        module.body.append(parser.parse_str(import_line))
    for dep in conversion_map.dependency_cache.values():
        module.body.append(dep)
    compiled_node, compiled_src = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # Names the compiled module already defines (compiled import aliases,
    # converted entities) take precedence: a blanket dict.update() would let
    # an unrelated global of the same name in the entry function's module
    # silently replace them.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        compiled_namespace = compiled_node.__dict__
        for key, val in six.get_function_globals(e).items():
            if key not in compiled_namespace:
                compiled_namespace[key] = val
    compiled_fn = getattr(compiled_node, name)

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled_fn
コード例 #4
0
    def test_entity_to_graph_call_tree(self):
        """Converting `f` also converts its callee `g` and renames the call."""
        def g(a):
            return a

        def f(a):
            return g(a)

        # NOTE(review): other call sites in this file pass a fourth argument
        # (api_module) to ConversionMap; confirm this 3-argument form matches
        # the ConversionMap signature this test runs against.
        conversion_map = conversion.ConversionMap(True, (), ())
        conversion.entity_to_graph(f, conversion_map, None, None)

        self.assertIn(f, conversion_map.dependency_cache)
        self.assertIn(g, conversion_map.dependency_cache)
        self.assertEqual('tf__f', conversion_map.dependency_cache[f].name)
        # The first statement of tf__f is `return g(a)`; its callee must have
        # been renamed to the converted tf__g.
        self.assertEqual(
            'tf__g', conversion_map.dependency_cache[f].body[0].value.func.id)
        self.assertEqual('tf__g', conversion_map.dependency_cache[g].name)
コード例 #5
0
    def test_entity_to_graph_call_tree(self):
        """Converting `f` also converts its callee `g` and renames the call."""
        def g(a):
            return a

        def f(a):
            return g(a)

        conversion_map = conversion.ConversionMap(True, (), (), None)
        conversion.entity_to_graph(f, conversion_map, None, None)

        self.assertIn(f, conversion_map.dependency_cache)
        self.assertIn(g, conversion_map.dependency_cache)
        self.assertEqual('tf__f', conversion_map.dependency_cache[f].name)
        # need the extra .body[0] in order to step past the with tf.name_scope('f')
        # that is added automatically
        self.assertEqual(
            'tf__g',
            conversion_map.dependency_cache[f].body[0].body[0].value.func.id)
        self.assertEqual('tf__g', conversion_map.dependency_cache[g].name)
コード例 #6
0
 def test_entity_to_graph_unsupported_types(self):
     """entity_to_graph raises ValueError for an unsupported entity (a str)."""
     # Build the map outside assertRaises so only entity_to_graph is expected
     # to raise; a ValueError from the constructor would otherwise make this
     # test pass for the wrong reason.
     conversion_map = conversion.ConversionMap(True, (), (), None)
     with self.assertRaises(ValueError):
         conversion.entity_to_graph('dummy', conversion_map, None, None)