def test_compile():
  """Compiling a parsed function yields a callable with readable source."""
  def f(x):
    return x * 2

  doubled = compile_.compile_function(quoting.parse_function(f))
  assert doubled(2) == 4
  header = inspect.getsource(doubled).split('\n')[0]
  assert header == 'def f(x):'

  # Free names in the compiled function are looked up in the given globals.
  def f(x):
    return y * 2

  scaled = compile_.compile_function(quoting.parse_function(f), {'y': 3})
  assert scaled(2) == 6
def resolve_calls(func):
  """Parse a function and annotate each call with the function it resolves to.

  Resolution uses the global and local namespace of `func`. Consequently,
  procedural parameters (functions passed in as arguments) and functions
  defined inside the body of `func` are not resolved, because neither is
  present in that namespace.

  The function definition itself is annotated as well, so that it can be
  matched against calls to it made from other functions.

  Args:
    func: The function whose calls should be resolved.

  Returns:
    node: An AST in which every `Call` node carries a `func` annotation holding
        the function handle the call resolves to.

  Raises:
    AttributeError: When a function used on the RHS of an assignment cannot
        be resolved (because it was passed as an argument or was defined in
        the body of the function).
  """
  tree = quoting.parse_function(func)
  resolver = ResolveCalls(func)
  resolver.visit(tree)
  return tree
def _wrap(body):
  """Embed a list of statements as the body of a function so it can be compiled.

  Args:
    body: A list of AST statements.

  Returns:
    The parsed module whose (single) function definition now has `body` as
    its body.
  """
  def f():
    pass
  module = quoting.parse_function(f)
  func_def = module.body[0]
  func_def.body = body
  return module
def test_full_gradient_replace():
  """Replace.FULL rewrites `d[x]` gradient markers into plain names."""
  def f(x, y):
    d[x] = d[y]

  tree = quoting.parse_function(f)
  transformer = template.ReplaceGradTransformer(template.Replace.FULL)
  new_tree = transformer.visit(tree)

  assign = new_tree.body[0].body[0]
  lhs = assign.targets[0]
  assert isinstance(lhs, gast.Name)
  assert lhs.id == 'bx'
  assert assign.value.id == 'by'
  # The rewritten tree must still be compilable.
  compile_.compile_function(new_tree)
def test_unused():
  """`annotate.unused` reports exactly the assignments whose value is never read."""
  def f(x):
    y = x * 2
    return x

  node = quoting.parse_function(f)
  assert annotate.unused(node) == {('y', node.body[0].body[0])}

  def f(x):
    y = x * 2
    return y

  assert not annotate.unused(quoting.parse_function(f))

  # Assignments inside a loop may be read on the next iteration.
  def f(x):
    while True:
      y = x * 2
      x = 3
    return y

  assert not annotate.unused(quoting.parse_function(f))
def test_resolve():
  """Calls resolve to the callee's handle; unresolvable calls raise."""
  def g(x):
    return 2 * x

  def f(x):
    return g(x)

  node = annotate.resolve_calls(f)
  assert anno.getanno(node.body[0].body[0].value, 'func') == g

  def f(x):
    return h(x)

  # `h` is not bound in this scope, so resolution is expected to fail.
  with pytest.raises(AttributeError):
    annotate.resolve_calls(f)
def test_comment():
  """Comments can be attached above, to the right of, or below a statement."""
  node = quoting.parse_function(f).body[0]

  def source_lines():
    return quoting.to_source(node).split('\n')

  comments.add_comment(node.body[0], 'foo', 'above')
  assert source_lines()[1].strip() == '# foo'

  comments.add_comment(node.body[0], 'foo', 'right')
  assert source_lines()[1].strip() == 'y = x # foo'

  comments.add_comment(node.body[0], 'foo', 'below')
  assert source_lines()[2].strip() == '# foo'
def test_insert():
  """A TreeTransformer can prepend a statement before the visited one."""
  def f(x):
    y = x
    return y

  node = quoting.parse_function(f)

  class Prepend(transformers.TreeTransformer):

    def visit_Assign(self, node):
      # Only act on the assignment to `y`. Without this check the
      # transformer would end up in an infinite loop.
      if node.targets[0].id != 'y':
        return node
      self.prepend(quoting.quote("x = 2 * x"))
      return node

  Prepend().visit(node)
  assert quoting.unquote(node).split('\n')[1].strip() == "x = 2 * x"
def test_remove():
  """A TreeTransformer can remove the statement it is visiting."""
  def f(x):
    while True:
      y = x
      z = y
    return y

  node = quoting.parse_function(f)

  class RemoveZ(transformers.TreeTransformer):

    def visit_Assign(self, node):
      # Drop the assignment to `z`; every other assignment is kept.
      if node.targets[0].id == 'z':
        self.remove(node)
      return node

  RemoveZ().visit(node)
  # With `z = y` gone, the line after `y = x` is the return statement.
  assert quoting.unquote(node).split('\n')[3].strip() == "return y"
def _create_joint(fwdbwd, func, wrt, input_derivative):
  """Create a user-friendly gradient function.

  By default, gradient functions expect the stack to be passed to them
  explicitly. This function modifies the function so that the stack doesn't
  need to be passed and gets initialized in the function body instead.

  For consistency, gradient functions always return a tuple, even if the
  gradient of only one input was required. We unpack the tuple if it is of
  length one.

  Args:
    fwdbwd: An AST. The function definition of the joint primal and adjoint.
        Its first argument is the stack and its second the incoming gradient.
        Its last statement is expected to return a tuple (its `.value.elts`
        is read).
    func: A function handle. The original function that was differentiated.
    wrt: A tuple of integers. The arguments with respect to which we
        differentiated.
    input_derivative: An `INPUT_DERIVATIVE` value. When it equals
        `INPUT_DERIVATIVE.DefaultOne`, the appended gradient argument gets a
        default value of `1.0`, so callers may omit it.

  Returns:
    The function definition of the new function. Note that `fwdbwd` itself is
    mutated (its body, arguments and name are replaced).
  """
  # Correct return to be a non-tuple if there's only one element
  retval = fwdbwd.body[-1]
  if len(retval.value.elts) == 1:
    retval.value = retval.value.elts[0]

  # Make a stack init statement
  init_stack = quoting.quote('%s = tangent.Stack()' % fwdbwd.args.args[0].id)
  init_stack = comments.add_comment(init_stack, 'Initialize the tape')

  # Prepend the stack init to the top of the function
  fwdbwd.body = [init_stack] + fwdbwd.body

  # Replace the function arguments with the original ones. The gradient
  # argument's name must be read before `fwdbwd.args` is overwritten.
  grad_name = fwdbwd.args.args[1].id
  fwdbwd.args = quoting.parse_function(func).body[0].args

  # Give the function a nice name
  fwdbwd.name = naming.joint_name(func, wrt)

  # Allow the initial gradient to be passed as a keyword argument
  fwdbwd = ast_.append_args(fwdbwd, [grad_name])
  if input_derivative == INPUT_DERIVATIVE.DefaultOne:
    fwdbwd.args.defaults.append(quoting.quote('1.0'))
  return fwdbwd
def test_anf():
  """ANF hoists nested calls/operations into temporaries and preserves results."""
  def g(x):
    return x * 2
  h = g

  def f(x):
    y = g(h(x))
    return y

  assert anf_lines(f)[1].strip() == "h_x = h(x)"
  assert anf_function(f, locals())(2) == 8

  def f(x):
    return x * x * x

  final_line = anf_lines(f)[-1]
  assert 'return' in final_line and '*' not in final_line
  assert anf_function(f)(2) == 8

  # Smoke test on a statement mixing literals, subscripts, attributes and calls.
  def f(x):
    y = [(x.y[0], ), 3]
    y += x * f(x[g(x)].b, (3, x / -2))

  assert anf.anf(quoting.parse_function(f))
def replace(template, replace_grad=Replace.PARTIAL, namer=None, **replacements):
  """Replace placeholders in a Python template (quote).

  Args:
    template: A function, AST node or string to be used as a template. Note
        that if a function is passed, any placeholder is expected to also be a
        function argument (including a `*args` parameter). If a string is
        passed, it must represent valid Python code, and any variable it
        references is a placeholder.
    replace_grad: If Replace.NONE, statements of the form `d[x]` are ignored.
        For the other possible values, see `ReplaceGradTransformer`.
    namer: See `ReplaceGradTransformer`.
    **replacements: A mapping from placeholder names to (lists of) AST nodes
        that these placeholders will be replaced by. If a string is passed,
        `quote` will be called on it to turn it into a node.

  Returns:
    body: An AST node or list of AST nodes with the replacements made. If the
        template was a function, a list will be returned. If the template was
        a node, the same node will be returned. If the template was a string,
        an AST node will be returned (a `Module` node in the case of a
        multi-line string, an `Expr` node otherwise).

  Raises:
    ValueError: If a function is used as a template and an incorrect set of
        replacements was passed.
  """
  # Handle the 3 different types of templates: funcs, nodes, and strings
  is_function = isinstance(template, types.FunctionType)
  if is_function:
    tree = quoting.parse_function(template).body[0]
    placeholders = set(arg.id for arg in tree.args.args)
    tree.args.args = []
    if tree.args.vararg:
      # Like the entries of `tree.args.args`, the vararg is an AST node
      # rather than a string. The replacement keys are strings, so compare by
      # identifier. Fall back to the raw value in case it is already a bare
      # identifier string.
      vararg = tree.args.vararg
      placeholders.add(getattr(vararg, 'id', vararg))
      tree.args.vararg = None
    if set(replacements.keys()) != placeholders:
      raise ValueError('too many or few replacements')
  elif isinstance(template, gast.AST):
    tree = template
  else:
    tree = quoting.quote(template, return_expr=True)

  # If the replacements are strings, turn them into nodes. Only existing keys
  # are reassigned, so mutating the dict while iterating is safe.
  for k, v in replacements.items():
    if isinstance(v, six.string_types):
      replacements[k] = quoting.quote(v)

  # Perform the replacement
  ReplaceTransformer(replacements).visit(tree)

  # Handle the d[x] operator
  if replace_grad is not Replace.NONE:
    rgt = ReplaceGradTransformer(
        replace_grad=replace_grad,
        namer=namer,
        tangent=replace_grad is Replace.TANGENT)
    rgt.visit(tree)

  # Return the AST node with replacements made
  if is_function:
    return tree.body
  else:
    return tree
def anf_lines(f):
  """Convert `f` to A-normal form and return the resulting source, line by line."""
  normalized = anf.anf(quoting.parse_function(f))
  source = quoting.unquote(normalized)
  return source.split('\n')
def _assert_tangent_parse_error(func, fragment):
  """Check that `fence.validate` rejects `func` with a message containing `fragment`.

  Args:
    func: The function whose source should fail validation.
    fragment: A substring expected in the `TangentParseError` message.

  Raises:
    AssertionError: If validation succeeds, or if the raised error's message
        does not contain `fragment`.
  """
  node = quoting.parse_function(func)
  source = inspect.getsource(func)
  try:
    fence.validate(node, source)
  except fence.TangentParseError as expected:
    assert fragment in str(expected)
  else:
    # Fail with a descriptive message instead of a bare `assert False`.
    raise AssertionError(
        'Expected fence.TangentParseError when validating %s' % func.__name__)
def visit_Call(self, node):
  """Create adjoint for call.

  We don't allow unpacking of parameters, so we know that each argument
  gets passed in explicitly, allowing us to create partials for each.
  However, templates might perform parameter unpacking (for cases where
  the number of arguments is variable) and express their gradient as a
  tuple. In this case, we have to unpack this tuple of partials.

  Cases handled, in order:
    * `func` is in `non_differentiable.NON_DIFFERENTIABLE`: the call is kept
      and no adjoint statements are produced.
    * `func` is `tracing.Traceable`: delegated to
      `self.primal_and_adjoint_for_tracing`.
    * `func` is in `grads.UNIMPLEMENTED_ADJOINTS`: an error is raised.
    * `func` has no entry in `grads.adjoints`: the call is redirected to a
      generated primal function (named by `naming.primal_name`) with
      `self.substack` prepended, and `(func, active_args)` is recorded in
      `self.required`. The adjoint calls the generated adjoint function and
      scatters its result tuple `dxs` into `d[x]` for each active argument.
    * otherwise: the registered adjoint template is bound to the call's
      arguments (missing ones filled from the template's defaults) and
      instantiated with `template.replace`.

  Args:
    node: A `Call` node carrying a `func` annotation.

  Returns:
    A `(primal, adjoint)` pair: the primal call node (the original `node` or a
    rewritten call) and a list of adjoint statements.

  Raises:
    errors.ReverseNotImplementedError: If `func` is in
        `grads.UNIMPLEMENTED_ADJOINTS`.
  """
  # Find the function we are differentiating
  func = anno.getanno(node, 'func')

  if func in non_differentiable.NON_DIFFERENTIABLE:
    return node, []

  if func == tracing.Traceable:
    return self.primal_and_adjoint_for_tracing(node)

  if func in grads.UNIMPLEMENTED_ADJOINTS:
    raise errors.ReverseNotImplementedError(func)

  # If we don't have an adjoint, we will have to step into the called
  # function and differentiate it
  if func not in grads.adjoints:
    # Every positional argument is assumed to be a Name (`arg.id` is read);
    # presumably the call was ANF-normalized beforehand -- TODO confirm.
    active_args = tuple(i for i, arg in enumerate(node.args)
                        if arg.id in self.active_variables)

    # Record (func, active_args) only once. Functions are compared by
    # __name__ and the active argument indices as a set.
    already_counted = False
    for f, a in self.required:
      if f.__name__ == func.__name__ and set(a) == set(active_args):
        already_counted = True
        break
    if not already_counted:
      self.required.append((func, active_args))

    pri_name = naming.primal_name(func, active_args)
    pri_call = gast.Call(
        func=gast.Name(id=pri_name, ctx=gast.Load(), annotation=None),
        args=[self.substack] + node.args,
        keywords=node.keywords)
    anno.setanno(pri_call, 'pri_call', True)

    dy = create.create_grad(self.target, self.namer)
    dy.ctx = gast.Load()
    # NOTE(review): `dx` is never used below. It is left in place because
    # `create.create_grad` may register a gradient name with `self.namer`
    # (a side effect that could affect later naming) -- confirm before
    # removing.
    dx = create.create_grad(node.args[0], self.namer)
    dx.ctx = gast.Store()
    adj_name = naming.adjoint_name(func, active_args)
    adj_call = gast.Call(
        func=gast.Name(id=adj_name, ctx=gast.Load(), annotation=None),
        args=[self.substack, dy] + node.args,
        keywords=node.keywords)
    anno.setanno(adj_call, 'adj_call', True)
    # The adjoint call returns one partial per active argument, in order;
    # scatter them into the corresponding d[x].
    adjoint = [template.replace('dxs = dfx', namer=self.namer, dfx=adj_call)]
    for j, i in enumerate(active_args):
      adjoint.append(template.replace('d[x] = dxs[i]', namer=self.namer,
                                      x=node.args[i].id, i=gast.Num(n=j)))
    return pri_call, adjoint

  # We have a template for the gradient that we need to fill in
  template_ = grads.adjoints[func]

  # Match the function call to the template. The template's first parameter
  # is the output, so it is dropped before binding the call's arguments.
  sig = funcsigs.signature(template_)
  sig = sig.replace(parameters=list(sig.parameters.values())[1:])
  kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
  bound_args = sig.bind(*node.args, **kwargs)

  # Fill in any missing kwargs with the defaults from the template.
  # Defaults align with the *last* positional args, hence the reversal. The
  # keys here are AST nodes (their `.id` is the parameter name).
  args = quoting.parse_function(template_).body[0].args
  kwargs = dict(zip(*map(reversed, [args.args, args.defaults])))
  kwargs.update(dict(zip(args.kwonlyargs, args.kw_defaults)))
  for arg, val in kwargs.items():
    if arg.id not in bound_args.arguments:
      bound_args.arguments[arg.id] = val

  # Let's fill in the template. The first argument is the output, which
  # was stored in a temporary variable
  output_name = six.get_function_code(template_).co_varnames[0]
  arg_replacements = {output_name: ast_.copy_node(self.target)}
  arg_replacements.update(bound_args.arguments)

  # If the template uses *args, then we pack the corresponding inputs
  packing = []
  flags = six.get_function_code(template_).co_flags

  if flags & inspect.CO_VARARGS:
    to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
    vararg_name = six.get_function_code(template_).co_varnames[-1]
    target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
    value = gast.Tuple(elts=to_pack, ctx=gast.Load())
    packing = [gast.Assign(targets=[target], value=value)]

    # And we fill in the packed tuple into the template
    arg_replacements[six.get_function_code(
        template_).co_varnames[-1]] = target
  adjoint = template.replace(template_, namer=self.namer, **arg_replacements)
  unpacking = []
  if flags & inspect.CO_VARARGS:
    # If the template packs arguments, then we have to unpack the
    # derivatives afterwards
    # We also have to update the replacements tuple then
    dto_pack = [create.create_temp_grad(arg, self.namer)
                for arg in to_pack]
    value = create.create_grad(target, self.namer)
    target = gast.Tuple(elts=dto_pack, ctx=gast.Store())
    unpacking = [gast.Assign(targets=[target], value=value)]

  return node, packing + adjoint + unpacking
def test_parses(func):
  """Smoke test: the source of every parametrized `func` parses into an AST."""
  parse = quoting.parse_function
  parse(func)
def visit_Call(self, node):
  """Replace a call with its forward-mode (tangent) counterpart.

  Cases handled, in order:
    * `self.target` is falsy: `node` is returned untouched.
    * `func` is in `tangents.UNIMPLEMENTED_TANGENTS` or is
      `tracing.Traceable`: an error is raised.
    * `func` has no template in `tangents.tangents`: the call is replaced by
      a call to a generated tangent function (named by
      `naming.tangent_name`) that receives the tangents of the `Name`
      arguments as keyword arguments. `(func, active_args)` is recorded in
      `self.required` and `self.value` is set to False.
    * otherwise: the tangent template is bound to the call's arguments
      (`grads.DEFAULT` placeholders are filled from `func`'s own defaults)
      and instantiated with `template.replace` in TANGENT mode. For
      `tangent.pop`, the first generated node is annotated with the matching
      push from `self.metastack` (or None).

  Args:
    node: A `Call` node carrying a `func` annotation.

  Returns:
    `node` when `self.target` is falsy, otherwise a list of AST nodes.

  Raises:
    errors.ForwardNotImplementedError: If `func` is in
        `tangents.UNIMPLEMENTED_TANGENTS`.
    NotImplementedError: If `func` is `tracing.Traceable`.
    ValueError: If `func` has no tangent template and its source cannot be
        parsed.
  """
  if not self.target:
    return node
  func = anno.getanno(node, 'func')

  if func in tangents.UNIMPLEMENTED_TANGENTS:
    raise errors.ForwardNotImplementedError(func)

  if func == tracing.Traceable:
    raise NotImplementedError('Tracing of %s is not enabled in forward mode' %
                              quoting.unquote(node))

  if func not in tangents.tangents:
    # Without a template we must be able to read `func`'s source. Catch
    # Exception (not a bare `except:`) so KeyboardInterrupt/SystemExit are
    # not turned into a ValueError.
    try:
      quoting.parse_function(func)
    except Exception:
      raise ValueError('No tangent found for %s, and could not get source.' %
                       func.__name__)

    # z = f(x,y) -> d[z],z = df(x,y,dx=dx,dy=dy)
    active_args = tuple(i for i, arg in enumerate(node.args)
                        if isinstance(arg, gast.Name))
    # TODO: Stack arguments are currently not considered
    # active, but for forward-mode applied to call trees,
    # they have to be. When we figure out how to update activity
    # analysis to do the right thing, we'll want to add the extra check:
    # `and arg.id in self.active_variables`

    # TODO: Duplicate of code in reverse_ad.
    already_counted = False
    for f, a in self.required:
      if f.__name__ == func.__name__ and set(a) == set(active_args):
        already_counted = True
        break
    if not already_counted:
      self.required.append((func, active_args))

    fn_name = naming.tangent_name(func, active_args)
    orig_args = quoting.parse_function(func).body[0].args
    tangent_keywords = []
    for i in active_args:
      # Pass the tangent of the i-th call argument under the name of the
      # tangent of the i-th parameter of `func`.
      grad_node = create.create_grad(node.args[i], self.namer, tangent=True)
      arg_grad_node = create.create_grad(
          orig_args.args[i], self.namer, tangent=True)
      grad_node.ctx = gast.Load()
      tangent_keywords.append(
          gast.keyword(arg=arg_grad_node.id, value=grad_node))
    # Update the original call
    rhs = gast.Call(
        func=gast.Name(id=fn_name, ctx=gast.Load(), annotation=None),
        args=node.args,
        keywords=tangent_keywords + node.keywords)
    # Set self.value to False to trigger whole primal replacement
    self.value = False
    return [rhs]

  template_ = tangents.tangents[func]

  # Match the function call to the template. The template's first parameter
  # is the output, so it is dropped before binding the call's arguments.
  sig = funcsigs.signature(template_)
  sig = sig.replace(parameters=list(sig.parameters.values())[1:])
  kwargs = dict((keyword.arg, keyword.value) for keyword in node.keywords)
  bound_args = sig.bind(*node.args, **kwargs)
  bound_args.apply_defaults()

  # If any keyword arguments weren't passed, we fill them using the
  # defaults of the original function
  if grads.DEFAULT in bound_args.arguments.values():
    # Build a mapping from names to defaults
    args = quoting.parse_function(func).body[0].args
    defaults = {}
    for arg, default in zip(*map(reversed, [args.args, args.defaults])):
      defaults[arg.id] = default
    for arg, default in zip(args.kwonlyargs, args.kw_defaults):
      if default is not None:
        defaults[arg.id] = default
    for name, value in bound_args.arguments.items():
      if value is grads.DEFAULT:
        bound_args.arguments[name] = defaults[name]

  # Let's fill in the template. The first argument is the output, which
  # was stored in a temporary variable
  output_name = six.get_function_code(template_).co_varnames[0]
  arg_replacements = {output_name: self.tmp_node}
  arg_replacements.update(bound_args.arguments)

  # If the template uses *args, then we pack the corresponding inputs
  flags = six.get_function_code(template_).co_flags

  if flags & inspect.CO_VARARGS:
    to_pack = node.args[six.get_function_code(template_).co_argcount - 1:]
    vararg_name = six.get_function_code(template_).co_varnames[-1]
    target = gast.Name(annotation=None, id=vararg_name, ctx=gast.Store())
    value = gast.Tuple(elts=to_pack, ctx=gast.Load())

    # And we fill in the packed tuple into the template
    arg_replacements[six.get_function_code(template_).co_varnames[
        -1]] = target
  tangent_node = template.replace(
      template_,
      replace_grad=template.Replace.TANGENT,
      namer=self.namer,
      **arg_replacements)

  # If the template uses the answer in the RHS of the tangent,
  # we need to make sure that the regular answer is replaced
  # with self.tmp_node, but that the gradient is not. We have
  # to be extra careful for statements like a = exp(a), because
  # both the target and RHS variables have the same name.
  tmp_grad_node = create.create_grad(self.tmp_node, self.namer, tangent=True)
  tmp_grad_name = tmp_grad_node.id
  ans_grad_node = create.create_grad(self.target, self.namer, tangent=True)
  for _node in tangent_node:
    for succ in gast.walk(_node):
      if isinstance(succ, gast.Name) and succ.id == tmp_grad_name:
        succ.id = ans_grad_node.id

  if flags & inspect.CO_VARARGS:
    # If the template packs arguments, then we have to unpack the
    # derivatives afterwards
    # We also have to update the replacements tuple then
    # NOTE(review): neither the packing tuple (`value` above) nor the
    # unpacking target/value built here is ever added to `tangent_node`, so
    # these assignments currently have no effect. Left as-is pending a
    # decision on the correct forward-mode packing semantics.
    dto_pack = [
        create.create_temp_grad(arg, self.namer, True) for arg in to_pack
    ]
    value = create.create_grad(target, self.namer, tangent=True)
    target = gast.Tuple(elts=dto_pack, ctx=gast.Store())

  # Stack pops have to be special-cased, we have
  # to set the 'push' attribute, so we know that if we
  # remove this pop, we have to remove the equivalent push.
  # NOTE: this only works if we're doing forward-over-reverse,
  # where reverse is applied in joint mode, with no call tree.
  # Otherwise, the pushes and pops won't be matched within a single
  # function call.
  if func == tangent.pop:
    if len(self.metastack):
      anno.setanno(tangent_node[0], 'push', self.metastack.pop())
    else:
      anno.setanno(tangent_node[0], 'push', None)
  return tangent_node
def anf_function(f, globals_=None):
  """Return a freshly compiled, ANF-transformed version of `f`.

  The source of `f` is parsed, converted to A-normal form, compiled and
  executed. The resulting new function object is returned (not `f` itself),
  so callers actually exercise the transformed code.

  Args:
    f: The function to transform.
    globals_: Optional dict used as the global namespace of the new function,
        e.g. the caller's `locals()` when `f` refers to names local to the
        caller. Defaults to this module's globals.

  Returns:
    The compiled ANF version of `f`.
  """
  module = gast.gast_to_ast(anf.anf(quoting.parse_function(f)))
  module = gast.fix_missing_locations(module)
  # Execute into a private namespace so the new definition can be retrieved
  # by name without clobbering anything in `globals_`.
  namespace = {}
  exec(compile(module, '<string>', 'exec'),
       globals() if globals_ is None else globals_, namespace)
  # `parse_function` yields a module whose first statement is the function
  # definition.
  return namespace[module.body[0].name]