def reverse_ad(node, wrt, preserve_result):
    """Perform reverse-mode AD on an AST.

    This function analyses the AST to determine which variables are active and
    proceeds by taking the naive derivative. Before returning the primal and
    adjoint it annotates push and pop statements as such.

    Args:
        node: A `FunctionDef` AST node.
        wrt: A tuple of argument indices with respect to which we take the
            derivative.
        preserve_result: A boolean indicating whether the generated derivative
            function should also return the original return value.

    Returns:
        mod: A `Module` node containing the naive primal and adjoint of the
            function which can be fed to the `split` and `joint` functions.
        required: A list of tuples of functions and argument indices. These
            functions were called by the function but did not have an adjoint.
    """
    if not isinstance(node, gast.FunctionDef):
        raise TypeError

    # Activity analysis
    cfg.forward(node, cfg.Active(wrt))

    ad = ReverseAD(wrt, preserve_result)
    pri, adj = ad.visit(node)
    mod = gast.Module(body=[pri, adj])
    mod = annotate.find_stacks(mod)
    return mod, ad.required, ad.stack
def prepare(self, node):
    assert isinstance(node, ast.Module)
    self.env = {
        'builtins': __import__('builtins'),
    }

    for module_name in MODULES:
        # __dispatch__ is the only fake top-level module
        if module_name != '__dispatch__':
            import_name = module_name
            alias_module_name = mangle(module_name)
            self.env[alias_module_name] = __import__(import_name)

            # handle functions conflicting with c++ keywords
            for fun in MODULES[module_name]:
                if fun in ("__theitemgetter__", "pythran"):
                    # these ones do not exist in Python
                    continue

    # we need to parse the whole code to be able to apply user-defined pure
    # functions, but imports are resolved before, so we remove them to avoid
    # ImportError (for operator_ for example)
    dummy_module = ast.Module([s for s in node.body
                               if not isinstance(s, ast.Import)], [])
    eval(compile(ast.gast_to_ast(dummy_module),
                 '<constant_folding>', 'exec'),
         self.env)

    super(ConstantFolding, self).prepare(node)
def to_graph(o, arg_value_hints=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Currently supported entities:
        * functions
        * classes

    Classes are handled by converting all their methods into a new class.

    Args:
        o: A Python function or class.
        arg_value_hints: A dict mapping parameter names to objects that can
            hint at the type of those parameters.

    Returns:
        A function with a signature identical to `o`, but which, when executed,
        creates a TF graph that has the same functionality as the original
        entity.
    """
    conversion_map = conversion.ConversionMap()
    _, name = conversion.object_to_graph(o, conversion_map, arg_value_hints)

    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        module.body.append(parser.parse_str(import_line))
    for dep in conversion_map.dependency_cache.values():
        module.body.append(dep)
    compiled_node = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(o):
        compiled_node.__dict__.update(six.get_function_globals(o))
    compiled_fn = getattr(compiled_node, name)
    return compiled_fn
def filter_code_typevars(module, duc, ancestors):
    """Create a filtered code with what is needed to create the annotations"""
    module_filtered = ast.Module()
    kept = module_filtered.body = []
    module_filtered.type_ignores = []
    suppressed = set()

    def fill_suppressed(def_):
        for user in def_.users():
            parent_in_body = ancestors.parents(user.node)[1]
            suppressed.add(parent_in_body)
            fill_suppressed(user)

    for node in module.body:
        if node in suppressed:
            continue
        if isinstance(node, ast.Import):
            if node.names[0].name in ["transonic", "numpy"]:
                kept.append(node)
            else:
                def_ = duc.chains[node.names[0]]
                fill_suppressed(def_)
            # suppressed.add()
        elif isinstance(node, ast.ImportFrom):
            if node.module in ["transonic", "numpy"]:
                kept.append(node)
        elif isinstance(node, (ast.Assign, ast.AugAssign)):
            kept.append(node)

    return extast.unparse(module_filtered)
def prepare(self, node):
    assert isinstance(node, ast.Module)
    self.env = {
        'builtins': __import__('builtins'),
    }

    for module_name in MODULES:
        # __dispatch__ is the only fake top-level module
        if module_name != '__dispatch__':
            alias_module_name = mangle(module_name)
            try:
                self.env[alias_module_name] = __import__(module_name)
            except ImportError:
                pass

    # we need to parse the whole code to be able to apply user-defined pure
    # functions, but imports are resolved before, so we remove them to avoid
    # ImportError (for operator_ for example)
    dummy_module = ast.Module([s for s in node.body
                               if not isinstance(s, ast.Import)], [])
    eval(compile(ast.gast_to_ast(dummy_module),
                 '<constant_folding>', 'exec'),
         self.env)

    super(ConstantFolding, self).prepare(node)
def to_graph(f, arg_value_hints=None):
    """Compile a Python function into equivalent TensorFlow code.

    Args:
        f: A Python function with arbitrary arguments and return values.
        arg_value_hints: A dict mapping parameter names to objects that can
            hint at the type of those parameters.

    Returns:
        A function with a signature identical to `f`, but which, when executed,
        creates a TF graph that has the same functionality as the original
        function.
    """
    conversion_map = conversion.ConversionMap()
    _, name = conversion.object_to_graph(f, conversion_map, arg_value_hints)

    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        module.body.append(parser.parse_str(import_line))
    for dep in conversion_map.dependency_cache.values():
        module.body.append(dep)
    compiled_node = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    compiled_node.__dict__.update(six.get_function_globals(f))
    compiled_fn = getattr(compiled_node, name)
    return compiled_fn
def visit_Lambda(self, node):
    op = issimpleoperator(node)
    if op is not None:
        if mangle('operator') not in self.global_declarations:
            import_ = ast.Import([ast.alias('operator', mangle('operator'))])
            self.imports.append(import_)
            operator_module = MODULES['operator']
            self.global_declarations[mangle('operator')] = operator_module
        return ast.Attribute(
            ast.Name(mangle('operator'), ast.Load(), None, None),
            op, ast.Load())

    self.generic_visit(node)
    forged_name = "{0}_lambda{1}".format(self.prefix,
                                         len(self.lambda_functions))

    ii = self.gather(ImportedIds, node)
    ii.difference_update(self.lambda_functions)  # remove current lambdas

    binded_args = [ast.Name(iin, ast.Load(), None, None)
                   for iin in sorted(ii)]
    node.args.args = ([ast.Name(iin, ast.Param(), None, None)
                       for iin in sorted(ii)] +
                      node.args.args)

    for patternname, pattern in self.patterns.items():
        if issamelambda(pattern, node):
            proxy_call = ast.Name(patternname, ast.Load(), None, None)
            break
    else:
        duc = ExtendedDefUseChains()
        nodepattern = deepcopy(node)
        duc.visit(ast.Module([ast.Expr(nodepattern)], []))
        self.patterns[forged_name] = nodepattern, duc

        forged_fdef = ast.FunctionDef(forged_name,
                                      copy(node.args),
                                      [ast.Return(node.body)],
                                      [], None, None)
        metadata.add(forged_fdef, metadata.Local())
        self.lambda_functions.append(forged_fdef)
        self.global_declarations[forged_name] = forged_fdef
        proxy_call = ast.Name(forged_name, ast.Load(), None, None)

    if binded_args:
        if MODULES['functools'] not in self.global_declarations.values():
            import_ = ast.Import(
                [ast.alias('functools', mangle('functools'))])
            self.imports.append(import_)
            functools_module = MODULES['functools']
            self.global_declarations[mangle('functools')] = functools_module

        return ast.Call(
            ast.Attribute(
                ast.Name(mangle('functools'), ast.Load(), None, None),
                "partial",
                ast.Load()),
            [proxy_call] + binded_args,
            [])
    else:
        return proxy_call
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Currently supported entities:
        * functions
        * classes

    Classes are handled by converting all their methods into a new class.

    Args:
        e: A Python entity.
        recursive: Whether to recursively convert any functions that the
            decorated function may call.
        verbose: Whether to output the compiled code in the logs.
        arg_values: A dict containing value hints for symbols like function
            parameters.
        arg_types: A dict containing type hints for symbols like function
            parameters.
        partial_types: A set of types (e.g. classes) that will not be converted
            entirely. Calls to member functions for these types will be renamed
            independently.

    Returns:
        A function with a signature identical to `e`, but which, when executed,
        creates a TF graph that has the same functionality as the original
        entity.
    """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, do_not_convert, converted_call),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    _, name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                         arg_types)

    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        module.body.extend(parser.parse_str(import_line).body)
    for dep in conversion_map.dependency_cache.values():
        module.body.append(dep)
    compiled_node, compiled_src = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        for key, val in inspect_utils.getnamespace(e).items():
            # Avoid overwriting entities that have been transformed.
            if key not in compiled_node.__dict__:
                compiled_node.__dict__[key] = val
    compiled_fn = getattr(compiled_node, name)

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled_fn
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Currently supported entities:
        * functions
        * classes

    Classes are handled by converting all their methods into a new class.

    Args:
        e: A Python entity.
        recursive: Whether to recursively convert any functions that the
            decorated function may call.
        verbose: Whether to output the compiled code in the logs.
        arg_values: A dict containing value hints for symbols like function
            parameters.
        arg_types: A dict containing type hints for symbols like function
            parameters.
        partial_types: A set of types (e.g. classes) that will not be converted
            entirely. Calls to member functions for these types will be renamed
            independently.

    Returns:
        A function with a signature identical to `e`, but which, when executed,
        creates a TF graph that has the same functionality as the original
        entity.
    """
    program_ctx = converter.ProgramContext(
        recursive=recursive,
        autograph_decorators=(convert, do_not_convert, converted_call),
        partial_types=partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
                                                    arg_types)

    module = gast.Module([])
    for dep in reversed(program_ctx.dependency_cache.values()):
        module.body.append(dep)
    compiled_node, compiled_src = compiler.ast_to_object(
        module, source_prefix=program_ctx.required_imports)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
        # Avoid overwriting entities that have been transformed.
        if key not in compiled_node.__dict__:
            compiled_node.__dict__[key] = val

    compiled_fn = getattr(compiled_node, name)

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled_fn
def autodiff_tree(func, wrt, motion, mode, preserve_result, check_dims,
                  verbose):
    """Perform AD on all functions in a call tree.

    This function walks the call tree and differentiates each function in it.
    It also ensures that the global namespaces that each function in the call
    tree was in are merged.

    The `tangent` and `numpy` packages are added to the namespace here, so that
    the gradient templates can assume that they are present.

    Args:
        See `grad`.

    Returns:
        final: A single module which contains the primals and adjoints of all
            the functions in the call tree.
        namespace: A merged dictionary with all the variables in the global
            namespaces of each function. The primals and adjoints need access
            to these in order to execute.
    """
    # Imported here to avoid circular imports
    import tangent
    namespace = {'tangent': tangent, 'numpy': numpy}

    done = set()
    final = gast.Module(body=[])
    namespace.update(six.get_function_globals(func))

    node, required = autodiff_ast(func, wrt, motion, mode, preserve_result,
                                  check_dims, verbose)
    final.body.extend(node.body)

    to_do = set(required)
    if motion == 'split' and mode == 'reverse':
        done.add((func, wrt))
        to_do -= done

    while to_do:
        func, wrt = to_do.pop()
        namespace.update(six.get_function_globals(func))

        node, required = autodiff_ast(
            func=func,
            wrt=wrt,
            motion='split',
            mode=mode,
            preserve_result=True,
            check_dims=False,
            verbose=verbose)

        final.body.extend(node.body)
        done.add((func, wrt))
        to_do.update(required)
        to_do -= done

    return final, namespace
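# Hedged usage sketch (based on the tangent project's documented top-level
# entry point, tangent.grad, which is not shown in the snippets above): grad()
# drives autodiff_tree/reverse_ad/joint to build, compile and return a
# derivative function. This is illustration only, not part of the code above.
import tangent

def f(x):
    return x * x

df = tangent.grad(f)   # derivative of f with respect to its first argument
print(df(3.0))         # expected to print 6.0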
def visit_If(self, node):
    """Intercepts if statements.

    Converts each `if` to up to two separate `with` statements,
    `ProgramBuilder.if_(condition_variable)` and `ProgramBuilder.else_()`. If
    the incoming `if` had one arm, returns the transformed AST node; if it had
    two, returns two nodes in a list.

    Args:
        node: An `ast.AST` node representing the `if` statement to convert.

    Returns:
        then_node: A node representing the `with`-guarded consequent branch.
        else_node: A node representing the `with`-guarded alternate branch,
            if present.
    """
    # Transform a branch
    # NOTE: this is a little hackery to make sure that prepending works
    #   properly. Wrapping a list of statements in a Module ensures
    #   that the AST-visiting machinery won't choke on, e.g., a list.
    then = self.generic_visit(gast.Module(node.body)).body
    # Construct header (goes in the `with`s).
    then_header = templates.replace_as_expression(
        '_tfp_autobatching_context_.if_(cond)',
        cond=self._to_reference(node.test))
    # Construct `with` node.
    # TODO(axch): Test that this form actually works with multiline bodies.
    then_node = templates.replace(
        'with header: body', header=then_header, body=then)[0]
    if node.orelse:
        orelse = self.generic_visit(gast.Module(node.orelse)).body
        orelse_header = templates.replace_as_expression(
            '_tfp_autobatching_context_.else_()')
        orelse_node = templates.replace(
            'with header: body', header=orelse_header, body=orelse)[0]
        # Return both
        return [then_node, orelse_node]
    else:
        return then_node
def build_example(size):
    tree = gast.Module(
        body=[gast.Constant(value=i, kind=None) for i in range(size)],
        type_ignores=[])
    py_graph, ast_to_node_id = py_ast_graphs.py_ast_to_graph(tree)
    edges = []
    for i in range(1, size, 2):
        edges.append((ast_to_node_id[id(tree.body[i])],
                      ast_to_node_id[id(tree.body[i - 1])], 1))
    return graph_bundle.convert_graph_with_edges(py_graph, edges,
                                                 py_ast_graphs.BUILDER)
def root_build(body):
    """Given a list of statements, puts them into a function in a module."""
    return gast.Module(
        body=[
            gast.FunctionDef(
                name="random_function",
                args=_make_arguments(
                    python_numbers_control_flow.make_name("a"),
                    python_numbers_control_flow.make_name("b")),
                body=body,
                decorator_list=[],
                returns=None,
                type_comment=None)
        ],
        type_ignores=[])
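# A minimal, self-contained sketch (not taken from any of the projects above)
# of the round trip that modules built this way rely on: a gast.Module can be
# converted back to the standard-library ast and compiled. The source string
# and function name here are made up for illustration.
import ast
import gast

src = "def random_function(a, b):\n    return a + b\n"
gast_module = gast.ast_to_gast(ast.parse(src))
assert isinstance(gast_module, gast.Module)

std_module = ast.fix_missing_locations(gast.gast_to_ast(gast_module))
namespace = {}
exec(compile(std_module, "<gast_example>", "exec"), namespace)
assert namespace["random_function"](1, 2) == 3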
def to_graph(e, recursive=True, arg_values=None, arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

    Currently supported entities:
        * functions
        * classes

    Classes are handled by converting all their methods into a new class.

    Args:
        e: A Python entity.
        recursive: Whether to recursively convert any functions that the
            decorated function may call.
        arg_values: A dict containing value hints for symbols like function
            parameters.
        arg_types: A dict containing type hints for symbols like function
            parameters.
        partial_types: A set of types (e.g. classes) that will not be converted
            entirely. Calls to member functions for these types will be renamed
            independently.

    Returns:
        A function with a signature identical to `e`, but which, when executed,
        creates a TF graph that has the same functionality as the original
        entity.
    """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, graph_ready, convert_inline),
        partial_types=partial_types)
    _, name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                         arg_types)

    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        module.body.append(parser.parse_str(import_line))
    for dep in conversion_map.dependency_cache.values():
        module.body.append(dep)
    compiled_node = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        compiled_node.__dict__.update(six.get_function_globals(e))
    compiled_fn = getattr(compiled_node, name)
    return compiled_fn
def joint(node):
    """Merge the bodies of primal and adjoint into a single function.

    Args:
        node: A module with the primal and adjoint function definitions as
            returned by `reverse_ad`.

    Returns:
        func: A `Module` node with a single function definition containing the
            combined primal and adjoint.
    """
    node, _, _ = _fix(node)
    body = node.body[0].body[:-1] + node.body[1].body
    func = gast.Module(body=[gast.FunctionDef(
        name=node.body[0].name,
        args=node.body[1].args,
        body=body,
        decorator_list=[],
        returns=None)])
    # Clean up
    anno.clearanno(func)
    return func
def prepare(self, node, ctx):
    assert isinstance(node, ast.Module)
    self.env = {
        '__builtin__': __import__('__builtin__'),
    }

    for module_name in MODULES:
        # __dispatch__ is the only fake top-level module
        if module_name != '__dispatch__':
            import_name = module_name

            # handle module name conflicting with c++ keywords
            if (module_name.endswith("_") and
                    module_name[:-1] in cxx_keywords):
                import_name = module_name[:-1]
            alias_module_name = mangle(module_name)
            self.env[alias_module_name] = __import__(import_name)

            # handle functions conflicting with c++ keywords
            for fun in MODULES[module_name]:
                if fun in ("__theitemgetter__", "pythran"):
                    # these ones do not exist in Python
                    continue
                # Set attributes pointing to each other for C++ keywords,
                # e.g. __builtin__.int_ points to __builtin__.int
                if not hasattr(self.env[alias_module_name], fun):
                    setattr(self.env[alias_module_name], fun,
                            getattr(self.env[alias_module_name],
                                    fun.strip("_")))

    # we need to parse the whole code to be able to apply user-defined pure
    # functions, but imports are resolved before, so we remove them to avoid
    # ImportError (for operator_ for example)
    dummy_module = ast.Module(
        [s for s in node.body if not isinstance(s, ast.Import)])
    eval(compile(ast.gast_to_ast(dummy_module),
                 '<constant_folding>', 'exec'),
         self.env)

    super(ConstantFolding, self).prepare(node, ctx)
def forward_ad(node, wrt, preserve_result=False, check_dims=True):
    """Perform forward-mode AD on an AST.

    This function analyses the AST to determine which variables are active and
    proceeds by taking the naive derivative. Before returning the primal and
    adjoint it annotates push and pop statements as such.

    Args:
        node: A `FunctionDef` AST node.
        wrt: A tuple of argument indices with respect to which we take the
            derivative.
        preserve_result: A boolean indicating whether the original
            non-differentiated function value should be returned.
        check_dims: A boolean indicating whether the provided derivatives
            should have the same shape as their corresponding arguments.

    Returns:
        mod: A `Module` node containing the naive primal and adjoint of the
            function which can be fed to the `split` and `joint` functions.
        required: A list of tuples of functions and argument indices. These
            functions were called by the function but did not have an adjoint.
    """
    if not isinstance(node, gast.FunctionDef):
        raise TypeError

    # Activity analysis
    cfg_obj = cfg.CFG.build_cfg(node)
    cfg.Active(range(len(node.args.args))).visit(cfg_obj.entry)

    # Build forward-mode function
    fad = ForwardAD(wrt, preserve_result, check_dims)
    node = fad.visit(node)

    # Annotate stacks
    node = annotate.find_stacks(node)

    # Clean up naive forward-mode code
    node = gast.Module([node])
    anno.clearanno(node)

    return node, fad.required
def test_usub(self):
    orig_ast = gast.ast_to_gast(ast.parse("-3"))
    target_ast = gast.Module(body=[gast.Expr(value=gast.Num(n=-3))])
    assert compare_ast(self.canonicalizer.visit(orig_ast), target_ast)
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes."""
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        node, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        if class_namespace is None:
            class_namespace = namespace
        else:
            class_namespace.update(namespace)
        converted_members[m] = node
    namer = program_ctx.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)

    # TODO(mdan): This needs to be explained more thoroughly.
    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated. Otherwise, it is marked for conversion
    # (as a side effect of the call to namer.compiled_class_name() followed by
    # program_ctx.update_name_map(namer)).
    output_nodes = []
    renames = {}
    bases = []
    for base in c.__bases__:
        if isinstance(object, base):
            bases.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            # This will trigger a conversion into a class with this name.
            alias = namer.compiled_class_name(base.__name__, base)
        bases.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
    program_ctx.update_name_map(namer)

    # Generate the definition of the converted class.
    output_nodes.append(
        gast.ClassDef(
            class_name,
            bases=bases,
            keywords=[],
            body=list(converted_members.values()),
            decorator_list=[]))
    node = gast.Module(output_nodes)

    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    node = qual_names.resolve(node)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    node = ast_util.rename_symbols(node, renames)

    return node, class_name, class_namespace
def visit_Module(self, node):
    new_node = gast.Module(
        self._visit(node.body),
        []  # type_ignores
    )
    return new_node
def visit_Module(self, node):
    """ Add itertools import for imap, izip or ifilter iterator. """
    self.generic_visit(node)
    importIt = ast.Import(names=[ast.alias(name='itertools', asname=None)])
    return ast.Module(body=([importIt] + node.body))
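# A minimal sketch (hypothetical transformer, not from the projects above) of
# how visit_Module hooks like the two above are driven: gast.NodeTransformer
# dispatches to visit_Module when the root Module node is visited, and the
# returned node replaces it. Assumes a gast release with gast.Constant and the
# Module.type_ignores field.
import ast
import gast

class PrependBanner(gast.NodeTransformer):
    def visit_Module(self, node):
        banner = gast.Expr(value=gast.Constant(value="rewritten", kind=None))
        return gast.Module(body=[banner] + node.body,
                           type_ignores=node.type_ignores)

tree = gast.ast_to_gast(ast.parse("x = 1"))
tree = PrependBanner().visit(tree)
assert isinstance(tree.body[0].value, gast.Constant)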
def test_usub(self):
    orig_ast = gast.ast_to_gast(ast.parse("-3"))
    target_ast = gast.Module(
        body=[gast.Expr(value=gast.Constant(value=-3, kind=None))],
        type_ignores=[])
    assert compare_ast(self.canonicalizer.visit(orig_ast), target_ast)
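# A small check (not part of the test suite above) illustrating why the two
# test_usub variants differ: older gast releases represented numeric literals
# as gast.Num, while recent releases use gast.Constant and give Module a
# type_ignores field. Assumes a recent gast release and Python 3.8+.
import ast
import gast

tree = gast.ast_to_gast(ast.parse("3"))
literal = tree.body[0].value
assert isinstance(literal, gast.Constant) and literal.value == 3
assert tree.type_ignores == []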