def dead_code_elimination(node):
    """Remove definitions whose values are never used.

    Uses `annotate.unused` to find definitions in the AST that are never read,
    and deletes those statements. Push/pop pairs are kept consistent: if a
    removed statement contains a node annotated with a `push` partner, that
    partner is removed as well. Because of this coupling, dead code
    elimination *must run on the primal and adjoint simultaneously*.

    Args:
        node: The AST to optimize.

    Returns:
        The optimized AST (the same object, modified in place, with its
        annotations cleared).
    """
    # Gather every unused definition, except argument lists and `for` loops,
    # which are never candidates for removal.
    doomed = set()
    for definition in annotate.unused(node):
        target = definition[1]
        if isinstance(target, (gast.arguments, gast.For)):
            continue
        doomed.add(target)

    # A pop that disappears must take its matching push with it.
    partners = set()
    for statement in doomed:
        for inner in gast.walk(statement):
            if anno.getanno(inner, 'push', False):
                partners.add(anno.getanno(inner, 'push'))
    doomed |= partners

    transformers.Remove(doomed).visit(node)
    anno.clearanno(node)
    return node
def assignment_propagation(node):
    """Perform assignment propagation.

    Assignment propagation is not a compiler optimization as much as a
    readability optimization. If a variable name is used only once, it gets
    renamed when possible e.g. `y = x; z = y` will become `z = x`.

    The fold is only performed when it is safe: `x` must not be reassigned
    between `y = x` and `z = y`. For example `y = x; x = 5; z = y` is left
    alone, because rewriting it to `x = 5; z = x` would change the value of
    `z`.

    Args:
        node: The AST to optimize.

    Returns:
        The optimized AST.
    """
    n_reads = read_counts(node)
    # Private sentinel so a missing annotation can be told apart from any
    # real annotation value.
    missing = object()

    to_remove = []
    for succ in gast.walk(node):
        # We are looking for an assignment of the form `a = b`:
        # a single Name target and a Name value.
        if not (isinstance(succ, gast.Assign) and
                isinstance(succ.value, gast.Name) and
                len(succ.targets) == 1 and
                isinstance(succ.targets[0], gast.Name)):
            continue
        rhs_name = succ.value.id
        reaching = anno.getanno(succ, 'definitions_in')
        # Find all the places where `b` was defined.
        rhs_defs = [def_[1] for def_ in reaching if def_[0] == rhs_name]
        # `b` must have been defined in exactly one place (not an argument),
        # as a plain single-target copy `b = x`, and be read nowhere else but
        # in `a = b`. The single-target check keeps a chained `b = c = x` from
        # being deleted along with `c`'s definition.
        if len(rhs_defs) != 1:
            continue
        rhs_def = rhs_defs[0]
        if not (isinstance(rhs_def, gast.Assign) and
                n_reads[rhs_def] == 1 and
                len(rhs_def.targets) == 1 and
                isinstance(rhs_def.value, gast.Name) and
                isinstance(rhs_def.targets[0], gast.Name)):
            continue

        # `b` holds the value `x` had when `b = x` ran. If `x` is redefined on
        # some path between `b = x` and `a = b`, the last such redefinition
        # reaches `a = b` and is itself reached by `b = x` (which is the only
        # definition of `b` reaching `a = b`). In that case rewriting to
        # `a = x` would read the wrong value, so the fold is skipped.
        source_name = rhs_def.value.id
        clobbered = False
        for def_ in reaching:
            if def_[0] != source_name:
                continue
            if isinstance(def_[1], gast.arguments):
                # Argument definitions happen only on function entry, so they
                # cannot lie between two statements of the body.
                continue
            def_in = anno.getanno(def_[1], 'definitions_in', missing)
            if def_in is missing or any(
                    d[0] == rhs_name and d[1] is rhs_def for d in def_in):
                # Unknown or reassigned in between: be conservative.
                clobbered = True
                break
        if clobbered:
            continue

        # Mark rhs_def for deletion
        to_remove.append(rhs_def)
        # Propagate the definition
        succ.value = rhs_def.value

    # Remove the definitions we folded
    transformers.Remove(to_remove).visit(node)
    anno.clearanno(node)
    return node
def joint(node):
    """Combine the primal and adjoint into one function definition.

    The statements of the primal (all but its final statement) are followed
    by the statements of the adjoint. The merged function keeps the primal's
    name but takes the adjoint's arguments.

    Args:
        node: A module with the primal and adjoint function definitions as
            returned by `reverse_ad`.

    Returns:
        func: A `Module` node with a single function definition containing the
            combined primal and adjoint.
    """
    fixed, _, _ = _fix(node)
    primal = fixed.body[0]
    adjoint = fixed.body[1]

    # The primal's last statement is dropped before splicing in the adjoint
    # (presumably its `return` — TODO confirm against `reverse_ad` output).
    merged_statements = list(primal.body[:-1])
    merged_statements.extend(adjoint.body)

    merged_def = gast.FunctionDef(
        name=primal.name,
        args=adjoint.args,
        body=merged_statements,
        decorator_list=[],
        returns=None)
    func = gast.Module(body=[merged_def])

    # Strip analysis annotations from the combined tree.
    anno.clearanno(func)
    return func
def split(node, stack):
    """Keep primal and adjoint separate, passing state between them.

    After normalizing the module with `_fix`, `store_state` inserts the
    statements that save state in the primal and restore it in the adjoint,
    using `stack` as the storage.

    Args:
        node: A module with the primal and adjoint function definitions as
            returned by `reverse_ad`.
        stack: The stack node to use for storing and restoring state.

    Returns:
        func: A `Module` node with two function definitions containing the
            primal and adjoint respectively.
    """
    fixed, defined, reaching = _fix(node)
    result = store_state(fixed, reaching, defined, stack)
    # Annotations are analysis scratch data; drop them before handing back.
    anno.clearanno(result)
    return result
def forward_ad(node, wrt, preserve_result=False, check_dims=True):
    """Perform forward-mode AD on an AST.

    Runs activity analysis over the function's control-flow graph, builds the
    forward-mode function with `ForwardAD`, annotates its stack operations
    with `annotate.find_stacks`, and wraps the result in a module.

    Note that activity analysis seeds *every* positional argument
    (`node.args.args`) as active, not only the ones listed in `wrt`.
    NOTE(review): confirm this is intentional.

    Args:
        node: A `FunctionDef` AST node.
        wrt: A tuple of argument indices with respect to which we take the
            derivative.
        preserve_result: A boolean indicating whether the original
            non-differentiated function value should be returned.
        check_dims: A boolean indicating whether the provided derivatives
            should have the same shape as their corresponding arguments.

    Returns:
        mod: A `Module` node wrapping the forward-mode function definition,
            with annotations cleared.
        required: A list of tuples of functions and argument indices, as
            collected by `ForwardAD` (`fad.required`): functions that were
            called by the function but for which no derivative was available.

    Raises:
        TypeError: If `node` is not a `gast.FunctionDef`.
    """
    if not isinstance(node, gast.FunctionDef):
        raise TypeError(
            'forward_ad expects a gast.FunctionDef node, got {}'.format(
                type(node).__name__))

    # Activity analysis
    cfg_obj = cfg.CFG.build_cfg(node)
    cfg.Active(range(len(node.args.args))).visit(cfg_obj.entry)

    # Build forward mode function
    fad = ForwardAD(wrt, preserve_result, check_dims)
    node = fad.visit(node)

    # Annotate stacks
    node = annotate.find_stacks(node)

    # Wrap in a module and clean up the analysis annotations
    node = gast.Module([node])
    anno.clearanno(node)

    return node, fad.required