def declare_function(self, name=None, type_inference=None):
  """Forward-declares a function whose body is supplied later.

  Pair this with `define_function` to build mutually recursive functions,
  since each body needs a handle to the other before both exist:

  ```python
  ab = dsl.ProgramBuilder()
  foo = ab.declare_function(...)
  with ab.function(...) as bar:
    ...
    ab.call(foo)
  with ab.define_function(foo):
    ...
    ab.call(bar)
  ```

  Calling a declared function that is never defined is an error.

  Args:
    name: Optional string naming this function when the program is printed.
    type_inference: A Python callable giving the type signature of the
      function being defined. See `function`.

  Returns:
    function: An `instructions.Function` object representing the declared
      function. Pass it to `call` to invoke it and to `define_function` to
      give it a body.
  """
  # Placeholder only: no graph and no input/output variables yet. Fresh lists
  # are built on every call, so declarations never share state.
  placeholder = inst.Function(
      name=name,
      type_inference=type_inference,
      graph=None,
      vars_in=[],
      vars_out=[],
  )
  return placeholder
def fibonacci_function_calls(include_types=True, dtype=np.int64):
  """The Fibonacci program again, but with `instructions.FunctionCallOp`.

  Computes fib(n): fib(0) = fib(1) = 1.

  Args:
    include_types: If False, we omit types on the variables, requiring a type
      inference pass.
    dtype: The dtype to use for `n`-like internal state variables.

  Returns:
    program: Full-powered `instructions.Program` that computes fib(n).
  """
  enter_fib = instructions.Block(name="enter_fib")
  recur = instructions.Block(name="recur")
  finish = instructions.Block(name="finish")

  def fibonacci_type(types):
    # The result type is taken to be the type of the first (only) input, `n`.
    return types[0]

  # Created with no graph so the recursive `FunctionCallOp`s below can refer
  # to it; the graph is attached once the blocks are filled in.
  fibonacci_func = instructions.Function(
      None, ["n"], "ans", fibonacci_type, name="fibonacci")

  # pylint: disable=bad-whitespace
  # Definition of fibonacci function
  enter_fib.assign_instructions([
      instructions.prim_op(["n"], "cond", lambda n: n > 1),  # cond = n > 1
      instructions.BranchOp("cond", recur, finish),           # if cond
  ])
  recur.assign_instructions([
      instructions.prim_op(["n"], "nm1", lambda n: n - 1),   # nm1 = n - 1
      instructions.FunctionCallOp(fibonacci_func, ["nm1"], "fibm1"),
      # fibm1 = fibonacci(nm1)
      instructions.prim_op(["n"], "nm2", lambda n: n - 2),   # nm2 = n - 2
      instructions.FunctionCallOp(fibonacci_func, ["nm2"], "fibm2"),
      # fibm2 = fibonacci(nm2)
      instructions.prim_op(
          ["fibm1", "fibm2"], "ans", lambda x, y: x + y),    # ans = fibm1 + fibm2
      instructions.halt_op(),                                 # return ans
  ])
  finish.assign_instructions([                                # else:
      instructions.prim_op([], "ans", lambda: 1),             # ans = 1
      instructions.halt_op(),                                 # return ans
  ])
  fibonacci_blocks = [enter_fib, recur, finish]
  fibonacci_func.graph = instructions.ControlFlowGraph(fibonacci_blocks)

  fibonacci_main_blocks = [
      instructions.Block(
          [
              instructions.FunctionCallOp(fibonacci_func, ["n1"], "ans"),
          ],
          instructions.halt_op(),
          name="main_entry"),
  ]

  # pylint: disable=bad-whitespace
  # `np.bool_` is used for booleans: the bare `np.bool` alias was removed in
  # NumPy 1.24 (and only re-added in 2.0), so it is not portable.
  fibonacci_vars = {
      "n": instructions.single_type(dtype, ()),
      "n1": instructions.single_type(dtype, ()),
      "cond": instructions.single_type(np.bool_, ()),
      "nm1": instructions.single_type(dtype, ()),
      "fibm1": instructions.single_type(dtype, ()),
      "nm2": instructions.single_type(dtype, ()),
      "fibm2": instructions.single_type(dtype, ()),
      "ans": instructions.single_type(dtype, ()),
  }
  if not include_types:
    _strip_types(fibonacci_vars)

  return instructions.Program(
      instructions.ControlFlowGraph(fibonacci_main_blocks),
      [fibonacci_func],
      fibonacci_vars,
      ["n1"],
      "ans")
def is_even_function_calls(include_types=True, dtype=np.int64):
  """The is-even program, via "even-odd" recursion.

  Computes True if the input is even and False if it is odd. It does this with
  a pair of mutually recursive functions, is_even and is_odd, which return True
  and False respectively for <1-valued inputs.

  Tests out mutual recursion.

  Args:
    include_types: If False, we omit types on the variables, requiring a type
      inference pass.
    dtype: The dtype to use for `n`-like internal state variables.

  Returns:
    program: Full-powered `instructions.Program` that computes is_even(n).
  """
  def pred_type(t):
    # Boolean result with the same shape as the first argument's type.
    # `np.bool_` rather than `np.bool`: the latter alias was removed in
    # NumPy 1.24 (and only re-added in 2.0).
    return instructions.TensorType(np.bool_, t[0].shape)

  # Forward declaration of is_odd: is_even's body calls it before is_odd's
  # own graph is built further down.
  is_odd_func = instructions.Function(None, ["n"], "ans", pred_type)

  enter_is_even = instructions.Block()
  finish_is_even = instructions.Block()
  recur_is_even = instructions.Block()
  is_even_func = instructions.Function(None, ["n"], "ans", pred_type)

  # pylint: disable=bad-whitespace
  # Definition of is_even function
  enter_is_even.assign_instructions([
      instructions.prim_op(["n"], "cond", lambda n: n < 1),  # cond = n < 1
      instructions.BranchOp("cond", finish_is_even, recur_is_even),  # if cond
  ])
  finish_is_even.assign_instructions([
      instructions.PopOp(["n", "cond"]),                      # done with n, cond
      instructions.prim_op([], "ans", lambda: True),          # ans = True
      instructions.halt_op(),                                 # return ans
  ])
  recur_is_even.assign_instructions([                         # else
      instructions.PopOp(["cond"]),                           # done with cond now
      instructions.prim_op(["n"], "nm1", lambda n: n - 1),    # nm1 = n - 1
      instructions.PopOp(["n"]),                              # done with n
      instructions.FunctionCallOp(is_odd_func, ["nm1"], "ans"),
      # ans = is_odd(nm1)
      instructions.PopOp(["nm1"]),                            # done with nm1
      instructions.halt_op(),                                 # return ans
  ])
  is_even_blocks = [enter_is_even, finish_is_even, recur_is_even]
  is_even_func.graph = instructions.ControlFlowGraph(is_even_blocks)

  enter_is_odd = instructions.Block()
  finish_is_odd = instructions.Block()
  recur_is_odd = instructions.Block()

  # pylint: disable=bad-whitespace
  # Definition of is_odd function
  enter_is_odd.assign_instructions([
      instructions.prim_op(["n"], "cond", lambda n: n < 1),  # cond = n < 1
      instructions.BranchOp("cond", finish_is_odd, recur_is_odd),  # if cond
  ])
  finish_is_odd.assign_instructions([
      instructions.PopOp(["n", "cond"]),                      # done with n, cond
      instructions.prim_op([], "ans", lambda: False),         # ans = False
      instructions.halt_op(),                                 # return ans
  ])
  recur_is_odd.assign_instructions([                          # else
      instructions.PopOp(["cond"]),                           # done with cond now
      instructions.prim_op(["n"], "nm1", lambda n: n - 1),    # nm1 = n - 1
      instructions.PopOp(["n"]),                              # done with n
      instructions.FunctionCallOp(is_even_func, ["nm1"], "ans"),
      # ans = is_even(nm1)
      instructions.PopOp(["nm1"]),                            # done with nm1
      instructions.halt_op(),                                 # return ans
  ])
  is_odd_blocks = [enter_is_odd, finish_is_odd, recur_is_odd]
  is_odd_func.graph = instructions.ControlFlowGraph(is_odd_blocks)

  is_even_main_blocks = [
      instructions.Block(
          [
              instructions.FunctionCallOp(is_even_func, ["n1"], "ans"),
          ],
          instructions.halt_op()),
  ]

  # pylint: disable=bad-whitespace
  is_even_vars = {
      "n": instructions.single_type(dtype, ()),
      "n1": instructions.single_type(dtype, ()),
      "cond": instructions.single_type(np.bool_, ()),
      "nm1": instructions.single_type(dtype, ()),
      "ans": instructions.single_type(np.bool_, ()),
  }
  if not include_types:
    _strip_types(is_even_vars)

  return instructions.Program(
      instructions.ControlFlowGraph(is_even_main_blocks),
      [is_even_func, is_odd_func],
      is_even_vars,
      ["n1"],
      "ans")