def dead_code_elimination(node):
    """Remove definitions whose values are never read.

    Runs `annotate.unused` on `node` and deletes every defining statement it
    reports, except function arguments (`gast.arguments`) and `for` loops.
    Statements slated for removal are then walked. Any descendant that
    carries a truthy `push` annotation causes the node stored in that
    annotation to be removed too, so a pop and its push disappear together.
    For that pairing to hold, the primal and the adjoint must be optimized
    in the same pass.

    Args:
        node: The AST to optimize.

    Returns:
        The optimized AST (the same object, modified in place), with all
        annotations cleared.
    """
    kept_types = (gast.arguments, gast.For)
    doomed = set()
    for definition in annotate.unused(node):
        defining_node = definition[1]
        if not isinstance(defining_node, kept_types):
            doomed.add(defining_node)

    # Iterate over a snapshot: pushes found here are added to `doomed`
    # but are not themselves scanned for further pushes.
    for statement in list(doomed):
        for descendant in gast.walk(statement):
            if anno.getanno(descendant, 'push', False):
                doomed.add(anno.getanno(descendant, 'push'))

    transformers.Remove(doomed).visit(node)
    anno.clearanno(node)
    return node
def visit(self, node):
    """Propagate dataflow facts forward from `node` until they stabilize.

    Regular CFG node (`node.value` is truthy):

    * The incoming set combines, via `self.op`, the out-sets of those
      predecessors that already carry an out annotation. When none do, it
      is an empty frozenset.
    * `self.gen(node, incoming)` returns the gen and kill sets.
    * The in, gen, kill and out sets are stored on `node.value` under
      `self.in_label`, `self.gen_label`, `self.kill_label` and
      `self.out_label`, where out = (incoming - kill) | gen.
    * Every successor is visited again, recursively, whenever the out-set
      differs from the previously stored one or none was stored.

    Exit node (`node.value` is falsy): the out-sets of all predecessors are
    combined with `self.op` into `self.exit`. With no predecessors,
    `self.exit` is an empty frozenset.

    Args:
        node: A CFG node with `value`, `prev` and `next` attributes.
    """
    if node.value:
        if anno.hasanno(node.value, self.out_label):
            before = anno.getanno(node.value, self.out_label)
        else:
            before = None
        preds = [
            anno.getanno(pred.value, self.out_label)
            for pred in node.prev
            if anno.hasanno(pred.value, self.out_label)
        ]
        if preds:
            incoming = functools.reduce(self.op, preds[1:], preds[0])
        else:
            incoming = frozenset()
        anno.setanno(node.value, self.in_label, incoming, safe=False)
        gen, kill = self.gen(node, incoming)
        anno.setanno(node.value, self.gen_label, gen, safe=False)
        anno.setanno(node.value, self.kill_label, kill, safe=False)
        out = (incoming - kill) | gen
        anno.setanno(node.value, self.out_label, out, safe=False)
        # Compare the sets themselves, not their hashes: two different sets
        # can share a hash, which would halt propagation before the
        # fixpoint is reached.
        if before is None or out != before:
            for succ in node.next:
                self.visit(succ)
    else:
        preds = [
            anno.getanno(pred.value, self.out_label)
            for pred in node.prev
        ]
        # Mirror the regular-node case instead of failing with IndexError
        # when the exit node has no predecessors.
        if preds:
            self.exit = functools.reduce(self.op, preds[1:], preds[0])
        else:
            self.exit = frozenset()
def test_defined():
    """`cfg.Defined` reports which variables are definitely defined in `g`."""
    node = tangent.quoting.parse_function(g)
    cfg.forward(node, cfg.Defined())
    statements = node.body[0].body

    # only x is for sure defined at the end
    defined_after_branch = anno.getanno(statements[1], 'defined_in')
    assert len(defined_after_branch) == 1

    # at the end of the if body both x and y are defined
    branch_statements = statements[0].body
    defined_in_branch = anno.getanno(branch_statements[0], 'defined_out')
    assert len(defined_in_branch) == 2
def visit(self, node):
    """Drop AD-generated push/pop pairs whose pushed variable has no definition.

    A node qualifies when it carries all of the `push_var`, `pop` and
    `gen_push` annotations. If the name of its `push_var` is absent from the
    names in its `definitions_in` annotation, the node and the node stored in
    its `pop` annotation are both passed to `self.remove`. In every case the
    node is then handed to the parent class's `visit`.
    """
    is_generated_push = (anno.hasanno(node, 'push_var') and
                         anno.hasanno(node, 'pop') and
                         anno.hasanno(node, 'gen_push'))
    if is_generated_push:
        # Definitions are (name, defining node) pairs; only names matter.
        reaching_names = frozenset(
            name for name, _ in anno.getanno(node, 'definitions_in'))
        pushed_name = ast_.get_name(anno.getanno(node, 'push_var'))
        if pushed_name not in reaching_names:
            self.remove(node)
            self.remove(anno.getanno(node, 'pop'))
    return super(CleanStack, self).visit(node)
def visit(self, node):
    """Collect all definitions and mark those that are read.

    A node annotated with `definitions_gen` adds those definitions to
    `self.definitions`. Its `definitions_in` becomes
    `self.reaching_definitions` while its subtree is visited, and the
    attribute is reset to None afterwards. Each `Name` in `Load` context
    marks, in `self.used`, every reaching definition whose name (element 0)
    equals the loaded id.
    """
    if anno.hasanno(node, 'definitions_gen'):
        self.definitions.update(anno.getanno(node, 'definitions_gen'))
        self.reaching_definitions = anno.getanno(node, 'definitions_in')

    is_read = isinstance(node, gast.Name) and isinstance(node.ctx, gast.Load)
    if is_read:
        read_name = node.id
        self.used.update(
            definition for definition in self.reaching_definitions
            if definition[0] == read_name)

    super(Unused, self).visit(node)

    if anno.hasanno(node, 'definitions_gen'):
        self.reaching_definitions = None
def _init(self, node):
    """Build the statement `<gradient name> = INIT_GRAD(<primal>)`.

    The target name is `ast_.get_name(node)`. The argument passed to
    `utils.INIT_GRAD` is `node`'s `adjoint_var` annotation when present,
    otherwise its `temp_adjoint_var` annotation.

    Returns:
        A `gast.Assign` node.
    """
    gradient_name = ast_.get_name(node)
    annotation_key = ('adjoint_var' if anno.hasanno(node, 'adjoint_var')
                      else 'temp_adjoint_var')
    primal = anno.getanno(node, annotation_key)

    target = gast.Name(id=gradient_name, ctx=gast.Store(), annotation=None)
    init_call = gast.Call(func=utils.INIT_GRAD, args=[primal], keywords=[])
    return gast.Assign(targets=[target], value=init_call)
def test_reaching():
    """Check the reaching-definition counts computed for `f`."""
    node = tangent.quoting.parse_function(f)
    cfg.forward(node, cfg.ReachingDefinitions())
    statements = node.body[0].body

    def n_reaching(statement):
        return len(anno.getanno(statement, 'definitions_in'))

    # Only the argument reaches the expression
    assert n_reaching(statements[0]) == 1

    loop_statements = statements[1].body
    # x can be either the argument here, or from the previous loop
    assert n_reaching(loop_statements[0]) == 2
    # x can only be the previous line here
    assert n_reaching(loop_statements[1]) == 1

    # x can be the argument here or the last definition from the while body
    assert n_reaching(statements[2]) == 2
def visit_FunctionDef(self, node):
    """Turn a primal function definition into its forward-mode version.

    The steps, in order:

    1. Visit each body statement and splice in the result (a visitor may
       return a node or a list/tuple of nodes).
    2. Append a tangent (seed-derivative) argument for each argument index
       in `self.wrt`.
    3. If `self.check_dims` is set, prepend one shape check per
       (primal, seed) pair.
    4. Prepend `d[z] = init_grad(z)` for every argument index not in
       `self.wrt`.
    5. Rename the function with `naming.tangent_name`.

    Args:
        node: The `FunctionDef` node to transform.

    Returns:
        The modified `node`.

    Raises:
        ValueError: If the number of seed arguments created differs from
            the number of requested `wrt` indices.
    """
    self.namer = naming.Namer.build(node)

    # Get the tangent of the body
    new_body = []
    for n in node.body:
        new = self.visit(n)
        if isinstance(new, (list, tuple)):
            new_body.extend(new)
        else:
            new_body.append(new)
    node.body = new_body

    # Add in the initial gradient argument
    grad_args = [
        create.create_grad(arg, self.namer, tangent=True)
        for i, arg in enumerate(node.args.args)
        if i in self.wrt
    ]
    if len(self.wrt) != len(grad_args):
        # Format the message string itself; applying `%` to the exception
        # object would raise a TypeError instead of this ValueError.
        raise ValueError(
            'Mismatch between requested and retrieved derivative arguments. '
            'Requested %d, found %d' % (len(self.wrt), len(grad_args)))

    node.args.args += grad_args

    if self.check_dims:
        # Define the shape check code quote
        def shape_match_template(primal, tangent_):
            if not tangent.shapes_match(primal, tangent_):
                raise ValueError(
                    'Shape mismatch between argument value (%s) and seed derivative '
                    '(%s)' % (numpy.shape(primal), numpy.shape(tangent_)))

        # Add a shape check for each seed derivative & primal pair.
        shape_check_nodes = []
        for iwrt, tangent_var in zip(self.wrt, grad_args):
            primal = node.args.args[iwrt]
            shape_check = template.replace(
                shape_match_template, primal=primal, tangent_=tangent_var)[0]
            shape_check_nodes.append(shape_check)
        node.body = shape_check_nodes + node.body

    # Add in gradient initialization statements for everything else.
    # NOTE(review): this enumerates the argument list *after* the seed
    # arguments were appended, so seed arguments whose index is not in
    # `self.wrt` also receive an init statement -- confirm this is intended.
    grad_init_nodes = [
        template.replace(
            'd[z] = init_grad(z)',
            replace_grad=template.Replace.TANGENT,
            namer=self.namer,
            z=arg,
            init_grad=utils.INIT_GRAD)
        for i, arg in enumerate(node.args.args)
        if i not in self.wrt
    ]
    node.body = grad_init_nodes + node.body

    # Rename the function
    func = anno.getanno(node, 'func')
    node.name = naming.tangent_name(func, self.wrt)
    return node
def visit(self, node, abort=astor.codegen.SourceGenerator.abort_visit):
    """Generate source for `node`, emitting its `comment` annotation if any.

    The `comment` annotation is a dict with `location` and `text` keys:

    * 'above': the comment goes on its own line before the node.
    * 'below': the comment goes on its own line after the node.
    * 'right': the comment goes at the end of the node's line.

    Above/below comment text is truncated to 78 characters, and the
    truncation is written back into the annotation dict. A node without a
    comment first sets `self.new_indentation` to False. `abort` is accepted
    but not used here.

    Raises:
        TangentParseError: If the comment location is not 'above', 'below'
            or 'right'.
    """
    parent = super(SourceWithCommentGenerator, self)
    if not anno.hasanno(node, 'comment'):
        self.new_indentation = False
        parent.visit(node)
        return

    comment = anno.getanno(node, 'comment')
    location = comment['location']
    # Preprocess the comment to fit to maximum line width of 80 characters
    # (only the text is counted, not the indentation or the '# ' prefix).
    max_text_width = 78
    if location in ('above', 'below'):
        comment['text'] = comment['text'][:max_text_width]
    # `self.new_indentation` (presumably set when a new body starts)
    # selects one separating newline instead of two.
    n_newlines = 1 if self.new_indentation else 2

    if location == 'above':
        self.result.append('\n' * n_newlines)
        self.result.append(self.indent_with * self.indentation)
        self.result.append('# %s' % comment['text'])
        parent.visit(node)
    elif location == 'below':
        parent.visit(node)
        self.result.append('\n')
        self.result.append(self.indent_with * self.indentation)
        self.result.append('# %s' % comment['text'])
        self.result.append('\n' * (n_newlines - 1))
    elif location == 'right':
        parent.visit(node)
        self.result.append(' # %s' % comment['text'])
    else:
        raise TangentParseError('Only valid comment locations are '
                                'above, below, right')
def visit_Expr(self, node):
    """Reverse-mode handling of expression statements.

    Only calls annotated with `func == utils.push` (e.g.
    `utils.push(_stack, x, op_id)`) produce adjoint code. The adjoint is:

    * the `grads.adjoints[utils.push]` template, instantiated with the
      call's stack, value and op_id;
    * followed by `d[x] = add_grad(d[x], <partial>)`, where `<partial>` is
      a copy of the first adjoint statement's target.

    Each adjoint statement gets the comment 'Grad of: <primal source>'.

    Returns:
        `(node, adjoint)`: the unchanged primal node and a list of adjoint
        statements (empty for any other expression).
    """
    call = node.value
    is_push = (isinstance(call, gast.Call) and
               anno.getanno(call, 'func') == utils.push)
    if not is_push:
        return node, []

    orig_src = quoting.unquote(node)
    stack, val, op_id = call.args
    adjoint_rhs = template.replace(
        grads.adjoints[utils.push],
        namer=self.namer,
        stack=stack,
        val=val,
        op_id=op_id)

    # Sum the temporary adjoint produced by the push adjoint into d[val].
    accumulation = template.replace(
        'd[xi] = add_grad(d[xi], dxi_partial)',
        namer=self.namer,
        replace_grad=template.Replace.FULL,
        xi=val,
        dxi_partial=ast_.copy_node(adjoint_rhs[0].targets[0]),
        add_grad=utils.ADD_GRAD)

    note = 'Grad of: %s' % orig_src
    adjoint = [comments.add_comment(statement, note)
               for statement in adjoint_rhs + [accumulation]]
    return node, adjoint
def is_active(self, node):
    """Return True if `node` loads any variable active on entry to it.

    The active set is `node`'s `active_in` annotation. Every `Name` in
    `Load` context within `node` (including `node` itself) is checked
    against it.
    """
    active_in = anno.getanno(node, 'active_in')
    return any(
        isinstance(child, gast.Name)
        and isinstance(child.ctx, gast.Load)
        and child.id in active_in
        for child in gast.walk(node))
def visit_Assign(self, node):
    """Simplify `dx = add_grad(dx, partial)` to `dx = partial` when dx is undefined.

    The rewrite applies when the RHS is a call whose `func` carries an
    `add_grad` annotation, and the first target's name is not among the
    names in the statement's `definitions_in`. The RHS is then replaced by
    the call's second positional argument.

    Returns:
        The (possibly modified) `node`.
    """
    rhs = node.value
    if not (isinstance(rhs, gast.Call) and anno.hasanno(rhs.func, 'add_grad')):
        return node
    # Definitions are (name, defining node) pairs; only names matter.
    reaching_names = frozenset(
        name for name, _ in anno.getanno(node, 'definitions_in'))
    if ast_.get_name(node.targets[0]) not in reaching_names:
        node.value = rhs.args[1]
    return node
def visit(self, node):
    """Dispatch to `visit_<ClassName>`, falling back to `generic_visit`.

    If `node` has an `active_in` annotation, it is stored as
    `self.active_variables` before the visitor runs. Child nodes visited
    afterwards therefore see the active set of their enclosing statement.

    Returns:
        Whatever the chosen visitor returns.
    """
    handler_name = 'visit_' + node.__class__.__name__
    handler = getattr(self, handler_name, self.generic_visit)
    if anno.hasanno(node, 'active_in'):
        self.active_variables = anno.getanno(node, 'active_in')
    return handler(node)
def _get_stack_op_handle(node):
    """Return the stack-operation function that a `Call` node invokes.

    The node's `func` annotation is used when it is truthy. Otherwise the
    callee's source text is looked up among 'tangent.pop', 'tangent.push',
    'tangent.pop_stack' and 'tangent.push_stack'.

    Args:
        node: A `gast.Call` node.

    Returns:
        The function handle, or False if it cannot be determined.
    """
    assert isinstance(node, gast.Call), 'Only can get fn handles of Call nodes'
    handle = anno.getanno(node, 'func', False)
    if handle:
        return handle

    stack_ops = {
        'tangent.pop': utils.pop,
        'tangent.push': utils.push,
        'tangent.pop_stack': utils.pop_stack,
        'tangent.push_stack': utils.push_stack,
    }
    return stack_ops.get(quoting.unquote(node.func), False)
def prepend_uninitialized_grads(self, node):
    """Emit initializations for gradients that are read before being defined.

    Only nodes with a `defined_in` annotation are inspected. Each `Name` in
    `Load` context inside such a node qualifies when all of these hold:

    * it carries an `adjoint_var` or `temp_adjoint_var` annotation;
    * its id is not in `defined_in`;
    * its id is not yet in `self.added`.

    A qualifying id is recorded in `self.added`, and `self._init(name)` is
    passed to `self.insert_top`.

    Returns:
        `node`, unchanged.
    """
    if not anno.hasanno(node, 'defined_in'):
        return node

    defined = anno.getanno(node, 'defined_in')
    for child in gast.walk(node):
        if not (isinstance(child, gast.Name) and isinstance(child.ctx, gast.Load)):
            continue
        is_gradient = (anno.hasanno(child, 'adjoint_var') or
                       anno.hasanno(child, 'temp_adjoint_var'))
        if is_gradient and child.id not in defined and child.id not in self.added:
            self.added.add(child.id)
            self.insert_top(self._init(child))
    return node
def test_resolve():
    """`annotate.resolve_calls` annotates calls with the function they resolve to."""
    def g(x):
        return 2 * x

    def f(x):
        return g(x)

    node = annotate.resolve_calls(f)
    assert anno.getanno(node.body[0].body[0].value, 'func') == g

    def f(x):
        return h(x)

    # `h` is presumably not resolvable from this scope, so resolution is
    # expected to fail. (The previously parsed-but-unused tree is gone.)
    with pytest.raises(AttributeError):
        annotate.resolve_calls(f)
def create_grad(node, namer, tangent=False):
    """Given a variable, create a variable for the gradient.

    Args:
        node: A node to create a gradient for, can be a normal variable (`x`)
            or a subscript (`x[i]`). A `Str` node is treated as the name it
            contains.
        namer: The namer object which will determine the name to use for the
            gradient.
        tangent: Whether a tangent (instead of adjoint) is created.

    Returns:
        node: A node representing the gradient with the correct name e.g. the
            gradient of `x[i]` is `dx[i]`.

            Note that this returns an invalid node, with the `ctx` attribute
            missing. It is assumed that this attribute is filled in later.

            Node has an `adjoint_var` annotation referring to the node it is
            an adjoint of.

    Raises:
        TypeError: If `node` (or the base of a subscript) is not a `Name`,
            `Subscript` or `Str` node.
    """
    if not isinstance(node, (gast.Subscript, gast.Name, gast.Str)):
        raise TypeError('Cannot create a gradient for a node of type %s; '
                        'expected Name, Subscript or Str' % type(node).__name__)

    # Temporaries stand in for another variable; differentiate that one.
    if anno.hasanno(node, 'temp_var'):
        return create_grad(anno.getanno(node, 'temp_var'), namer, tangent)

    def _name_grad(node):
        if not isinstance(node, gast.Name):
            raise TypeError('Expected a Name node, got %s' % type(node).__name__)
        varname = node.id
        name = namer.grad(varname, tangent)
        grad_node = gast.Name(id=name, ctx=None, annotation=None)
        anno.setanno(grad_node, 'adjoint_var', node)
        return grad_node

    if isinstance(node, gast.Subscript):
        # d(x[i]) is dx[i]: differentiate the base, keep the same slice.
        grad_node = create_grad(node.value, namer, tangent=tangent)
        grad_node.ctx = gast.Load()
        return gast.Subscript(value=grad_node, slice=node.slice, ctx=None)
    elif isinstance(node, gast.Str):
        grad_node = create_grad(
            gast.Name(id=node.s, ctx=None, annotation=None), namer, tangent=tangent)
        return gast.Str(grad_node.id)
    else:
        return _name_grad(node)
def assignment_propagation(node):
    """Perform assignment propagation.

    Assignment propagation is not a compiler optimization as much as a
    readability optimization. If a variable name is used only once, it gets
    renamed when possible e.g. `y = x; z = y` will become `z = x`.

    Args:
        node: The AST to optimize.

    Returns:
        The optimized AST (modified in place), with all annotations cleared.
    """
    n_reads = read_counts(node)
    folded = []
    for statement in gast.walk(node):
        # Only single-target copies of the form `a = b` are candidates.
        is_copy = (isinstance(statement, gast.Assign) and
                   isinstance(statement.value, gast.Name) and
                   len(statement.targets) == 1 and
                   isinstance(statement.targets[0], gast.Name))
        if not is_copy:
            continue

        source_name = statement.value.id
        # Every place that may have defined `b` at this point.
        candidates = [
            definition[1]
            for definition in anno.getanno(statement, 'definitions_in')
            if definition[0] == source_name
        ]
        if len(candidates) != 1:
            continue

        origin = candidates[0]
        # Fold only when `b` came from a single copy `b = x` (not an
        # argument) and this statement is its only reader.
        if (isinstance(origin, gast.Assign) and
                n_reads[origin] == 1 and
                isinstance(origin.value, gast.Name) and
                isinstance(origin.targets[0], gast.Name)):
            folded.append(origin)
            statement.value = origin.value

    # Remove the definitions we folded
    transformers.Remove(folded).visit(node)
    anno.clearanno(node)
    return node
def active(node, incoming):
    """Gen/kill function for active-variable analysis.

    * Argument nodes (`gast.arguments`) generate the names of the arguments
      at the indices in `wrt`. `wrt` is a free variable, presumably bound
      by an enclosing scope that is not visible here.
    * An assignment generates its updated names when its RHS is a call
      annotated with `func == utils.pop`, or when the RHS mentions any name
      in `incoming`.
    * Any other assignment kills its updated names.

    Returns:
        A `(gen, kill)` pair of sets.
    """
    gen = set()
    kill = set()
    statement = node.value

    if isinstance(statement, gast.arguments):
        gen.update(statement.args[i].id for i in wrt)

    if isinstance(statement, gast.Assign):
        rhs = statement.value
        # Special-case e.g. x = tangent.pop(_stack)
        # such that all values popped off the stack are live.
        is_pop = anno.getanno(rhs, 'func', False) == utils.pop
        makes_active = is_pop or any(
            isinstance(child, gast.Name) and child.id in incoming
            for child in gast.walk(rhs))
        if makes_active:
            gen.update(ast_.get_updated(statement))
        else:
            kill.update(ast_.get_updated(statement))

    return gen, kill
def visit_For(self, node):
    """Rewrite `for a in x` into an explicitly indexed loop when `x` is active.

    Iterators of the form `for a in x` rely on an implicit indexing
    operation, which Tangent cannot reverse without more information. When
    the iterable is a `Name`, `Subscript` or `Attribute` whose base name is
    in the loop's `active_in` annotation, the loop

        for a in x:
            f(a)

    becomes

        for _idx in range(len(x)):
            a = x[_idx]
            f(a)

    Integer indexes are used, which will cause strange behavior if the
    iterator's `next()` behavior deviates from a plain incrementing index.
    The right thing to do (eventually) is to write a multiple-dispatch
    version of the `next` operator, and its adjoint, so that we can handle
    e.g. dicts.

    Returns:
        The (possibly rewritten) `node`.
    """
    if isinstance(node.iter, (gast.Name, gast.Subscript, gast.Attribute)):
        # NOTE(review): `ast` here is presumably Tangent's AST helper module
        # (referred to as `ast_` elsewhere) -- confirm the import.
        iter_name = ast.get_name(node.iter)
        if iter_name in anno.getanno(node, 'active_in'):
            # Get a unique iterator name
            old_target = copy.deepcopy(node.target)
            new_target = quoting.quote(self.namer.unique('_idx'))
            old_iter = copy.deepcopy(node.iter)

            item_access = template.replace(
                'old_target = x[i]', old_target=old_target, x=old_iter, i=new_target)

            node.target = gast.Name(id=new_target.id, ctx=gast.Store(),
                                    annotation=None)
            # Take the length of the full iterable expression, not of its base
            # name: for `for a in x[0]` the loop must run len(x[0]) times,
            # not len(x). For a plain Name both are the same.
            node.iter = quoting.quote(
                'range(len(%s))' % quoting.unquote(node.iter))
            anno.setanno(node.iter, 'func', range)
            anno.setanno(node.iter.args[0], 'func', len)
            node.body = [item_access] + node.body

    return node
def is_insert_grad_of_statement(node):
    """Check whether a context manager calls `insert_grad_of`.

    Args:
        node: The context manager node.

    Returns:
        True if every `with` item's context expression is annotated with
        `func is utils.insert_grad_of` (vacuously True for no items), False
        if none is.

    Raises:
        ValueError: If the `insert_grad_of` calls are mixed with other calls.
    """
    n_items = len(node.items)
    n_insert_grad = sum(
        1 for item in node.items
        if anno.getanno(item.context_expr, 'func', None) is utils.insert_grad_of)
    if n_insert_grad == n_items:
        return True
    if n_insert_grad:
        raise ValueError
    return False
def is_active(self, node):
    """Checks whether a statement is active.

    An assignment is active when its right hand side contains active
    variables, i.e. a `Name` in `Load` context whose id is in
    `self.active_variables`. Pops (RHS calls annotated with
    `func == utils.pop`) always count as active.

    Args:
        node: an instance of gast.Assign

    Returns:
        Whether the statement is active.
    """
    rhs = node.value
    # Special case: If the right hand side is a pop statement, we want to
    # process it
    if isinstance(rhs, gast.Call) and anno.getanno(rhs, 'func', False) == utils.pop:
        return True
    active = self.active_variables
    return any(
        isinstance(child, gast.Name)
        and isinstance(child.ctx, gast.Load)
        and child.id in active
        for child in gast.walk(rhs))
def visit_FunctionDef(self, node):
    """Produce the forward-mode (tangent) version of a function definition.

    The steps, in order:

    1. Visit each body statement and splice in the result (a visitor may
       return a node or a list/tuple of nodes).
    2. Append a tangent argument for every argument index in `self.wrt`.
    3. Prepend `d[z] = init_grad(z)` for every argument index not in
       `self.wrt`.
    4. Rename the function with `naming.tangent_name`.

    Returns:
        The modified `node`.
    """
    self.namer = naming.Namer.build(node)

    tangent_body = []
    for statement in node.body:
        visited = self.visit(statement)
        tangent_body.extend(
            visited if isinstance(visited, (list, tuple)) else [visited])
    node.body = tangent_body

    # Seed-derivative arguments for the requested inputs.
    seed_args = [
        create.create_grad(param, self.namer, tangent=True)
        for index, param in enumerate(node.args.args)
        if index in self.wrt
    ]
    node.args.args += seed_args

    # NOTE(review): the argument list enumerated here already includes the
    # seed arguments appended above, so those also receive init statements
    # when their index is not in `self.wrt` -- confirm this is intended.
    init_statements = []
    for index, param in enumerate(node.args.args):
        if index in self.wrt:
            continue
        init_statements.append(template.replace(
            'd[z] = init_grad(z)',
            replace_grad=template.Replace.TANGENT,
            namer=self.namer,
            z=param,
            init_grad=utils.INIT_GRAD))
    node.body = init_statements + node.body

    node.name = naming.tangent_name(anno.getanno(node, 'func'), self.wrt)
    return node
def remove_repeated_comments(node):
    """Remove comments that repeat themselves.

    Multiple statements might be annotated with the same comment. That way,
    if one of the statements is deleted during optimization passes, the
    comment won't be lost. This pass removes sequences of identical comment
    texts, keeping only the first one. "Sequence" here means consecutive
    commented nodes in `gast.walk` order.

    Args:
        node: An AST

    Returns:
        An AST where comments are not repeated in sequence.
    """
    previous_text = None
    for child in gast.walk(node):
        if not anno.hasanno(child, 'comment'):
            continue
        comment = anno.getanno(child, 'comment')
        if comment['text'] == previous_text:
            anno.delanno(child, 'comment')
        previous_text = comment['text']
    return node
def visit_Expr(self, node):
    """Forward-mode handling of expression statements.

    Calls to `tangent.push` / `tangent.push_stack` (by `func` annotation)
    are special-cased, because the usual tangent/primal order has to be
    reversed. The matching template from `tangents.tangents` is
    instantiated with the call's stack, value and op_id, and emitted after
    the primal push. The first tangent statement is also pushed onto
    `self.metastack`, so that pushes and pops can be related in tangent
    mode.

    Returns:
        `node` for any other expression, otherwise `[node] + tangent
        statements`.
    """
    call = node.value
    if not isinstance(call, gast.Call):
        return node
    fn = anno.getanno(call, 'func')
    if fn not in [tangent.push, tangent.push_stack]:
        return node

    push_template = tangents.tangents[fn]
    stack_node, var_node, op_id = call.args
    tangent_node = template.replace(
        push_template,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        x=var_node,
        stack=stack_node,
        op_id=op_id)
    # Remember the tangent push on a "meta-stack" so a later pop can be
    # linked to it.
    self.metastack.append(tangent_node[0])
    return [node] + tangent_node
def visit(self, node):
    """Visit a node.

    This method is largely modelled after the ast.NodeTransformer class.
    When `node` has an `active_in` annotation, that set is stored in
    `self.active_variables` before dispatching. Afterwards the links
    between primal and adjoint are recorded:

    * every primal node gets an `adj` annotation (if not already set);
    * every adjoint node gets a `pri` annotation (if not already set).

    Args:
        node: The node to visit.

    Returns:
        A tuple of the primal and adjoint, each of which is a node or a list
        of nodes.

    Raises:
        ValueError: If there is no `visit_<ClassName>` method for `node`.
    """
    node_type = node.__class__.__name__
    handler_name = 'visit_' + node_type
    if not hasattr(self, handler_name):
        raise ValueError('Unknown node type: %s' % node_type)
    handler = getattr(self, handler_name)

    # If this node is a statement, inform all child nodes what the active
    # variables in this statement are
    if anno.hasanno(node, 'active_in'):
        self.active_variables = anno.getanno(node, 'active_in')

    pri, adj = handler(node)

    primal_nodes = [pri] if isinstance(pri, gast.AST) else pri
    for primal_node in primal_nodes:
        anno.setdefaultanno(primal_node, 'adj', adj)
    adjoint_nodes = [adj] if isinstance(adj, gast.AST) else adj
    for adjoint_node in adjoint_nodes:
        anno.setdefaultanno(adjoint_node, 'pri', pri)
    return pri, adj
def visit(self, node):
    """Pre-define pushed variables that are not defined at the push.

    For a node with a `push_var` annotation whose name is not in the node's
    `defined_in` annotation, the statement `<name> = None` is passed to
    `self.insert_top`. This is presumably so that the push never reads an
    undefined variable. The node is then handed to the parent class's
    `visit`.
    """
    if not anno.hasanno(node, 'push_var'):
        return super(FixStack, self).visit(node)

    pushed_name = ast_.get_name(anno.getanno(node, 'push_var'))
    if pushed_name not in anno.getanno(node, 'defined_in'):
        self.insert_top(quoting.quote('{} = None'.format(pushed_name)))
    return super(FixStack, self).visit(node)
def visit_Call(self, node):
    """Forward-mode transformation of a call on the RHS of an assignment.

    Behavior depends on the function given by the call's `func` annotation:

    * Without an assignment target (`self.target` falsy), the call is
      returned untouched.
    * Without a tangent template (function not in `tangents.tangents`), the
      call is differentiated through a call tree. `z = f(x, y)` becomes a
      call to `naming.tangent_name(f, active_args)`, with the tangents of
      the positional `Name` arguments passed as keywords.
      `(func, active_args)` is recorded in `self.required` unless an
      equivalent entry exists, and `self.value` is set to False to trigger
      whole-primal replacement.
    * With a template, the template is instantiated with the call's bound
      arguments, and the output is stored in `self.tmp_node`. Unfilled
      `grads.DEFAULT` arguments take the original function's defaults.

    Args:
        node: A `gast.Call` node.

    Returns:
        `node` unchanged, a one-element list holding the tangent call, or
        the list of statements produced from the tangent template.

    Raises:
        errors.ForwardNotImplementedError: If `func` is in
            `tangents.UNIMPLEMENTED_TANGENTS`.
        NotImplementedError: If `func` is `tracing.Traceable`.
        ValueError: If `func` has no tangent template and its source cannot
            be parsed.
    """
    if not self.target:
        return node
    func = anno.getanno(node, 'func')

    if func in tangents.UNIMPLEMENTED_TANGENTS:
        raise errors.ForwardNotImplementedError(func)
    if func == tracing.Traceable:
        raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                                  quoting.unquote(node))

    if func not in tangents.tangents:
        try:
            func_def = quoting.parse_function(func)
        except Exception:
            # Catch ordinary errors only; a bare `except:` would also swallow
            # KeyboardInterrupt and SystemExit.
            raise ValueError('No tangent found for %s, and could not get source.' %
                             func.__name__)

        # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
        active_args = tuple(i for i, arg in enumerate(node.args)
                            if isinstance(arg, gast.Name))
        # TODO: Stack arguments are currently not considered
        # active, but for forward-mode applied to call trees,
        # they have to be. When we figure out how to update activity
        # analysis to do the right thing, we'll want to add the extra check:
        # `and arg.id in self.active_variables`

        # TODO: Duplicate of code in reverse_ad.
        already_counted = False
        for f, a in self.required:
            if f.__name__ == func.__name__ and set(a) == set(active_args):
                already_counted = True
                break
        if not already_counted:
            self.required.append((func, active_args))

        fn_name = naming.tangent_name(func, active_args)
        # Reuse the tree parsed above instead of parsing the source again.
        orig_args = func_def.body[0].args
        tangent_keywords = []
        for i in active_args:
            grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
            arg_grad_node = create.create_grad(
                orig_args.args[i], self.namer, tangent=True)
            grad_node.ctx = gast.Load()
            tangent_keywords.append(
                gast.keyword(arg=arg_grad_node.id, value=grad_node))
        # Update the original call
        rhs = gast.Call(
            func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
            args=node.args,
            keywords=tangent_keywords + node.keywords)
        # Set self.value to False to trigger whole primal replacement
        self.value = False
        return [rhs]

    template_ = tangents.tangents[func]

    # Match the function call to the template (its first parameter is the
    # output, so it is dropped from the signature used for binding).
    sig = funcsigs.signature(template_)
    sig = sig.replace(parameters=list(sig.parameters.values())[1:])
    kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
    bound_args = sig.bind(*node.args, **kwargs)
    bound_args.apply_defaults()

    # If any keyword arguments weren't passed, we fill them using the
    # defaults of the original function
    if grads.DEFAULT in bound_args.arguments.values():
        # Build a mapping from names to defaults
        args = quoting.parse_function(func).body[0].args
        defaults = {}
        for arg, default in zip(*map(reversed, [args.args, args.defaults])):
            defaults[arg.id] = default
        for arg, default in zip(args.kwonlyargs, args.kw_defaults):
            if default is not None:
                defaults[arg.id] = default
        for name, value in bound_args.arguments.items():
            if value is grads.DEFAULT:
                bound_args.arguments[name] = defaults[name]

    # Let's fill in the template. The first argument is the output, which
    # was stored in a temporary variable
    output_name = six.get_function_code(template_).co_varnames[0]
    arg_replacements = {output_name: self.tmp_node}
    arg_replacements.update(bound_args.arguments)

    # If the template uses *args, then we pack the corresponding inputs
    flags = six.get_function_code(template_).co_flags

    if flags & inspect.CO_VARARGS:
        to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
        vararg_name = six.get_function_code(template_).co_varnames[-1]
        target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
        # And we fill in the packed tuple into the template
        arg_replacements[six.get_function_code(template_).co_varnames[
            -1]] = target
    tangent_node = template.replace(
        template_,
        replace_grad=template.Replace.TANGENT,
        namer=self.namer,
        **arg_replacements)

    # If the template uses the answer in the RHS of the tangent,
    # we need to make sure that the regular answer is replaced
    # with self.tmp_node, but that the gradient is not. We have
    # to be extra careful for statements like a = exp(a), because
    # both the target and RHS variables have the same name.
    tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
    tmp_grad_name = tmp_grad_node.id
    ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
    for _node in tangent_node:
        for succ in gast.walk(_node):
            if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
                succ.id = ans_grad_node.id

    if flags & inspect.CO_VARARGS:
        # If the template packs arguments, then we have to unpack the
        # derivatives afterwards
        # We also have to update the replacements tuple then
        # NOTE(review): `dto_pack`, `value` and `target` computed here are
        # not used afterwards; the calls are kept because they go through
        # `self.namer` and may reserve names.
        dto_pack = [
            create.create_temp_grad(arg, self.namer, True) for arg in to_pack
        ]
        value = create.create_grad(target, self.namer, tangent=True)
        target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

    # Stack pops have to be special-cased, we have
    # to set the 'push' attribute, so we know that if we
    # remove this pop, we have to remove the equivalent push.
    # NOTE: this only works if we're doing forward-over-reverse,
    # where reverse is applied in joint mode, with no call tree.
    # Otherwise, the pushes and pops won't be matched within a single
    # function call.
    if func == tangent.pop:
        if self.metastack:
            anno.setanno(tangent_node[0], 'push', self.metastack.pop())
        else:
            anno.setanno(tangent_node[0], 'push', None)
    return tangent_node
def test_active2():
    """Activity of argument 1 propagates so that three variables end up active."""
    node = tangent.quoting.parse_function(i)
    cfg.forward(node, cfg.Active(wrt=(1, )))
    last_statement = node.body[0].body[-1]
    # through y both x and z are now active
    active_at_end = anno.getanno(last_statement, 'active_out')
    assert len(active_at_end) == 3
def test_active():
    """Overwriting the active variable leaves nothing active at the end of `h`."""
    node = tangent.quoting.parse_function(h)
    cfg.forward(node, cfg.Active(wrt=(1, )))
    final_statement = node.body[0].body[-1]
    # y has been overwritten here, so nothing is active anymore
    assert not anno.getanno(final_statement, 'active_out')