예제 #1
0
    def declare_function(self, name=None, type_inference=None):
        """Forward-declares a function to be defined later with `define_function`.

        This is useful for defining mutually recursive functions:
        ```python
        ab = dsl.ProgramBuilder()

        foo = ab.declare_function(...)

        with ab.function(...) as bar:
          ...
          ab.call(foo)

        with ab.define_function(foo):
          ...
          ab.call(bar)
        ```

        It is an error to call but never define a declared function.

        Args:
          name: Optional string naming this function when the program is printed.
          type_inference: A Python callable giving the type signature of the
            function being defined.  See `function`.

        Returns:
          function: An `instructions.Function` object representing the function
            being declared.  It can be passed to `call` to call it, and to
            `define_function` to define it.
        """
        # The function starts with no graph and no input/output variables; those
        # are presumably populated later by `define_function` (not visible here).
        # Fresh list literals are built on every call, so declarations never
        # share mutable state.
        return inst.Function(graph=None,
                             vars_in=[],
                             vars_out=[],
                             type_inference=type_inference,
                             name=name)
def fibonacci_function_calls(include_types=True, dtype=np.int64):
    """The Fibonacci program again, but with `instructions.FunctionCallOp`.

    Computes fib(n): fib(0) = fib(1) = 1.

    Args:
      include_types: If False, we omit types on the variables, requiring a type
          inference pass.
      dtype: The dtype to use for `n`-like internal state variables.

    Returns:
      program: Full-powered `instructions.Program` that computes fib(n).
    """
    enter_fib = instructions.Block(name="enter_fib")
    recur = instructions.Block(name="recur")
    finish = instructions.Block(name="finish")

    # The result type of fibonacci is the type of its single argument `n`.
    fibonacci_type = lambda types: types[0]
    # Created without a graph so the recursive calls below can reference it;
    # the graph is attached once the blocks are filled in.
    fibonacci_func = instructions.Function(None, ["n"],
                                           "ans",
                                           fibonacci_type,
                                           name="fibonacci")
    # pylint: disable=bad-whitespace
    # Definition of fibonacci function
    enter_fib.assign_instructions([
        instructions.prim_op(["n"], "cond", lambda n: n > 1),  # cond = n > 1
        instructions.BranchOp("cond", recur, finish),  # if cond
    ])
    recur.assign_instructions([
        instructions.prim_op(["n"], "nm1", lambda n: n - 1),  #   nm1 = n - 1
        instructions.FunctionCallOp(fibonacci_func, ["nm1"],
                                    "fibm1"),  #   fibm1 = fibonacci(nm1)
        instructions.prim_op(["n"], "nm2", lambda n: n - 2),  #   nm2 = n - 2
        instructions.FunctionCallOp(fibonacci_func, ["nm2"],
                                    "fibm2"),  #   fibm2 = fibonacci(nm2)
        instructions.prim_op(["fibm1", "fibm2"], "ans",
                             lambda x, y: x + y),  #   ans = fibm1 + fibm2
        instructions.halt_op(),  #   return ans
    ])
    finish.assign_instructions([  # else:
        instructions.prim_op([], "ans", lambda: 1),  #   ans = 1
        instructions.halt_op(),  #   return ans
    ])
    fibonacci_blocks = [enter_fib, recur, finish]
    fibonacci_func.graph = instructions.ControlFlowGraph(fibonacci_blocks)

    # Main program: a single block that calls fibonacci on the input `n1`.
    fibonacci_main_blocks = [
        instructions.Block([
            instructions.FunctionCallOp(fibonacci_func, ["n1"], "ans"),
        ],
                           instructions.halt_op(),
                           name="main_entry"),
    ]

    # pylint: disable=bad-whitespace
    # `np.bool_` rather than `np.bool`: the `np.bool` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, whereas `np.bool_` works on all versions.
    fibonacci_vars = {
        "n": instructions.single_type(dtype, ()),
        "n1": instructions.single_type(dtype, ()),
        "cond": instructions.single_type(np.bool_, ()),
        "nm1": instructions.single_type(dtype, ()),
        "fibm1": instructions.single_type(dtype, ()),
        "nm2": instructions.single_type(dtype, ()),
        "fibm2": instructions.single_type(dtype, ()),
        "ans": instructions.single_type(dtype, ()),
    }
    if not include_types:
        _strip_types(fibonacci_vars)

    return instructions.Program(
        instructions.ControlFlowGraph(fibonacci_main_blocks), [fibonacci_func],
        fibonacci_vars, ["n1"], "ans")
def is_even_function_calls(include_types=True, dtype=np.int64):
    """The is-even program, via "even-odd" recursion.

    Computes True if the input is even, False if the input is odd, by a pair of
    mutually recursive functions is_even and is_odd, which return True and False
    respectively for <1-valued inputs.

    Tests out mutual recursion.

    Args:
      include_types: If False, we omit types on the variables, requiring a type
          inference pass.
      dtype: The dtype to use for `n`-like internal state variables.

    Returns:
      program: Full-powered `instructions.Program` that computes is_even(n).
    """
    def pred_type(t):
        # Boolean result with the same shape as the first argument's type.
        # `np.bool_` rather than the removed `np.bool` alias (NumPy >= 1.24).
        return instructions.TensorType(np.bool_, t[0].shape)

    # Forward declaration of is_odd: created without a graph so is_even can
    # call it; its graph is attached after is_even is built.
    is_odd_func = instructions.Function(None, ["n"], "ans", pred_type)

    enter_is_even = instructions.Block()
    finish_is_even = instructions.Block()
    recur_is_even = instructions.Block()
    is_even_func = instructions.Function(None, ["n"], "ans", pred_type)
    # pylint: disable=bad-whitespace
    # Definition of is_even function
    enter_is_even.assign_instructions([
        instructions.prim_op(["n"], "cond", lambda n: n < 1),  # cond = n < 1
        instructions.BranchOp("cond", finish_is_even,
                              recur_is_even),  # if cond
    ])
    finish_is_even.assign_instructions([
        instructions.PopOp(["n", "cond"]),  #   done with n, cond
        instructions.prim_op([], "ans", lambda: True),  #   ans = True
        instructions.halt_op(),  #   return ans
    ])
    recur_is_even.assign_instructions([  # else
        instructions.PopOp(["cond"]),  #   done with cond now
        instructions.prim_op(["n"], "nm1", lambda n: n - 1),  #   nm1 = n - 1
        instructions.PopOp(["n"]),  #   done with n
        instructions.FunctionCallOp(is_odd_func, ["nm1"],
                                    "ans"),  #   ans = is_odd(nm1)
        instructions.PopOp(["nm1"]),  #   done with nm1
        instructions.halt_op(),  #   return ans
    ])
    is_even_blocks = [enter_is_even, finish_is_even, recur_is_even]
    is_even_func.graph = instructions.ControlFlowGraph(is_even_blocks)

    enter_is_odd = instructions.Block()
    finish_is_odd = instructions.Block()
    recur_is_odd = instructions.Block()
    # pylint: disable=bad-whitespace
    # Definition of is_odd function (mirrors is_even with the base case False
    # and the recursive call going back to is_even).
    enter_is_odd.assign_instructions([
        instructions.prim_op(["n"], "cond", lambda n: n < 1),  # cond = n < 1
        instructions.BranchOp("cond", finish_is_odd, recur_is_odd),  # if cond
    ])
    finish_is_odd.assign_instructions([
        instructions.PopOp(["n", "cond"]),  #   done with n, cond
        instructions.prim_op([], "ans", lambda: False),  #   ans = False
        instructions.halt_op(),  #   return ans
    ])
    recur_is_odd.assign_instructions([  # else
        instructions.PopOp(["cond"]),  #   done with cond now
        instructions.prim_op(["n"], "nm1", lambda n: n - 1),  #   nm1 = n - 1
        instructions.PopOp(["n"]),  #   done with n
        instructions.FunctionCallOp(is_even_func, ["nm1"],
                                    "ans"),  #   ans = is_even(nm1)
        instructions.PopOp(["nm1"]),  #   done with nm1
        instructions.halt_op(),  #   return ans
    ])
    is_odd_blocks = [enter_is_odd, finish_is_odd, recur_is_odd]
    is_odd_func.graph = instructions.ControlFlowGraph(is_odd_blocks)

    # Main program: a single block that calls is_even on the input `n1`.
    is_even_main_blocks = [
        instructions.Block([
            instructions.FunctionCallOp(is_even_func, ["n1"], "ans"),
        ], instructions.halt_op()),
    ]
    # pylint: disable=bad-whitespace
    # `np.bool_` rather than `np.bool`: the `np.bool` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, whereas `np.bool_` works on all versions.
    is_even_vars = {
        "n": instructions.single_type(dtype, ()),
        "n1": instructions.single_type(dtype, ()),
        "cond": instructions.single_type(np.bool_, ()),
        "nm1": instructions.single_type(dtype, ()),
        "ans": instructions.single_type(np.bool_, ()),
    }
    if not include_types:
        _strip_types(is_even_vars)

    return instructions.Program(
        instructions.ControlFlowGraph(is_even_main_blocks),
        [is_even_func, is_odd_func], is_even_vars, ["n1"], "ans")