def function_to_graph(f, conversion_map, arg_values, arg_types,
                      owner_type=None):
  """Converts a callable into its graph-ready AST node.

  This is the specialization of `entity_to_graph` for callable functions.

  Args:
    f: The function (or method) to convert.
    conversion_map: The conversion state shared across entities. It supplies
      the namer, the `recursive` flag and the no-compile decorators, and it
      receives the updated name map and the additional imports.
    arg_values: Value hints forwarded to the `EntityContext`.
    arg_types: Type hints forwarded to the `EntityContext`.
    owner_type: Optional owner type of `f`, forwarded to both the
      `EntityContext` and the namer.

  Returns:
    A tuple `(node, new_name)`: the converted function node, already renamed,
    and the name it was given.

  Raises:
    NotImplementedError: If the namer did not rename the function and the
      parsed node's name differs from `f.__name__`.
  """
  parsed_module, source_code = parser.parse_entity(f)
  # The function definition is the first statement of the parsed entity.
  fn_node = parsed_module.body[0]

  fn_namespace = inspect_utils.getnamespace(f)
  _add_self_references(fn_namespace, conversion_map.api_module)
  fn_namer = conversion_map.new_namer(fn_namespace)

  entity_ctx = context.EntityContext(
      namer=fn_namer,
      source_code=source_code,
      source_file='<fragment>',
      namespace=fn_namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type,
      recursive=conversion_map.recursive)
  fn_node, dependencies = node_to_graph(
      fn_node, entity_ctx, conversion_map.nocompile_decorators)

  # TODO(mdan): This somewhat duplicates the call rename logic in
  # call_trees.py.
  original_name = f.__name__
  compiled_name, was_renamed = fn_namer.compiled_function_name(
      original_name, f, owner_type)
  if not was_renamed:
    # Keeping the original name is only valid when the parsed definition
    # carries that very same name.
    if fn_node.name != original_name:
      raise NotImplementedError('Strange corner case. Send us offending code!')
    compiled_name = original_name

  fn_node.name = compiled_name
  conversion_map.update_name_map(fn_namer)
  # TODO(mdan): Use this at compilation.
  conversion_map.additional_imports.update(dependencies)

  return fn_node, compiled_name
def class_to_graph(c, conversion_map):
  """Specialization of `entity_to_graph` for classes.

  Every member returned by
  `tf_inspect.getmembers(c, predicate=tf_inspect.ismethod)` is converted with
  `function_to_graph`, and the converted nodes become the body of a new
  `gast.ClassDef` that has no bases, keywords or decorators.

  Args:
    c: The class to convert.
    conversion_map: The conversion state shared across entities; used to
      convert each member and to create the namer for the class.

  Returns:
    A tuple `(node, class_name)`: the new `gast.ClassDef` node and the
    compiled class name assigned by the namer.

  Raises:
    ValueError: If `c` has no member methods.
  """
  members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod)
  if not members:
    # Interpolate the class so the message actually names the offender.
    raise ValueError(
        'Cannot convert %s: it has no member methods.' % (c,))

  converted_members = {}
  class_namespace = None
  for _, m in members:
    node, _ = function_to_graph(
        m,
        conversion_map=conversion_map,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    # TODO(mdan): Do not assume all members have the same view of globals.
    if class_namespace is None:
      class_namespace = inspect_utils.getnamespace(m)
    converted_members[m] = node

  namer = conversion_map.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)
  node = gast.ClassDef(
      class_name,
      bases=[],
      keywords=[],
      # Materialize a real list: on Python 3 `dict.values()` is a view, which
      # is not a valid container for an AST `body` field.
      body=list(converted_members.values()),
      decorator_list=[])

  return node, class_name
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the
      converted entity may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
      parameters.
    arg_types: A dict containing type hints for symbols like function
      parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
      entirely. Calls to member functions for these types will be renamed
      independently.

  Returns:
    The compiled counterpart of `e`, fetched by its compiled name from the
    module object built out of the converted code. For a function, this is a
    function with a signature identical to `e` which, when executed, creates
    a TF graph that has the same functionality as the original entity.
  """
  conv_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, graph_ready, convert_inline),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, compiled_name = conversion.entity_to_graph(
      e, conv_map, arg_values, arg_types)

  # Assemble one module: the standard imports first, then every converted
  # dependency collected during conversion.
  module = gast.Module([])
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    module.body.extend(parser.parse_str(import_line).body)
  module.body.extend(conv_map.dependency_cache.values())
  compiled_module, compiled_source = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    entry_namespace = inspect_utils.getnamespace(e)
    compiled_module.__dict__.update(entry_namespace)
  compiled_entity = getattr(compiled_module, compiled_name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_source)

  return compiled_entity
def test_getnamespace_hermetic(self):
  # Checks that `getnamespace` exposes the closure's binding of
  # `free_function` without writing it into the function's real globals.

  # Intentionally hiding the global function to make sure we don't overwrite
  # it in the global namespace.
  free_function = object()  # pylint:disable=redefined-outer-name

  def test_fn():
    return free_function

  ns = inspect_utils.getnamespace(test_fn)
  globs = six.get_function_globals(test_fn)
  # Identity assertions report both operands on failure, unlike
  # assertTrue/assertFalse applied to an `is` expression.
  self.assertIs(ns['free_function'], free_function)
  self.assertIsNot(globs['free_function'], free_function)
def test_getnamespace_locals(self):
  # Checks that names closed over by a local function appear in its
  # namespace, while variables local to the function body do not.

  def called_fn():
    return 0

  closed_over_list = []
  closed_over_primitive = 1

  def local_fn():
    closed_over_list.append(1)
    local_var = 1
    return called_fn() + local_var + closed_over_primitive

  ns = inspect_utils.getnamespace(local_fn)
  self.assertEqual(ns['called_fn'], called_fn)
  self.assertEqual(ns['closed_over_list'], closed_over_list)
  self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
  # assertNotIn shows the offending container when the check fails.
  self.assertNotIn('local_var', ns)
def test_getnamespace_globals(self):
  # The namespace computed for the module-level `factory` must expose the
  # module-level `free_function` under its own name.
  factory_namespace = inspect_utils.getnamespace(factory)
  self.assertEqual(factory_namespace['free_function'], free_function)