コード例 #1
0
ファイル: dsl.py プロジェクト: zhengzhuang3/probability
    def call(self, function, vars_in, vars_out=None):
        """Registers a function call instruction in the current block.

        Example:
        ```
        ab = dsl.ProgramBuilder()

        # Define a function
        with ab.function(...) as func:
          ...
          # Call it (recursively)
          ab.var.thing = ab.call(func, ...)
          ...
        ```

        Args:
          function: The `instructions.Function` object representing the function
            to call.
          vars_in: Python strings giving the variables to pass in as inputs.
            Each one must already be defined (present in `self._var_defs`).
          vars_out: A pattern of Python strings, giving the auto-batched
            variable(s) to which to write the result of the call.  Defaults to
            the empty list.

        Raises:
          ValueError: If the call references undefined auto-batched variables.
            The first undefined name (in `vars_in` order) is reported.

        Returns:
          op: An `instructions.FunctionCallOp` representing the call.  If one
            subsequently assigns this to a local, via
            `ProgramBuilder.var.foo = op`, that local gets added to the list of
            output variables.
        """
        # Validate inputs before touching builder state, so a bad call leaves
        # the program under construction unchanged.
        for name in vars_in:
            if name in self._var_defs:
                continue
            raise ValueError('Referencing undefined variable {}.'.format(name))

        self._prepare_for_instruction()

        # `None` sentinel instead of a mutable `[]` default argument.
        outputs = [] if vars_out is None else vars_out
        op = inst.FunctionCallOp(
            function, _str_list(vars_in), inst.pattern_map(str, outputs))
        self._blocks[-1].instructions.append(op)

        # Every leaf name in the output pattern now counts as defined.
        for name in inst.pattern_traverse(outputs):
            self._mark_defined(name)
        return op
コード例 #2
0
def fibonacci_function_calls(include_types=True, dtype=np.int64):
    """The Fibonacci program again, but with `instructions.FunctionCallOp`.

    Computes fib(n), with fib(0) = fib(1) = 1, through a `fibonacci` function
    that calls itself recursively on n - 1 and n - 2 and adds the results.

    Args:
      include_types: If False, we omit types on the variables, requiring a type
        inference pass.
      dtype: The dtype to use for `n`-like internal state variables.

    Returns:
      program: Full-powered `instructions.Program` that computes fib(n).
    """
    enter_fib = instructions.Block(name="enter_fib")
    recur = instructions.Block(name="recur")
    finish = instructions.Block(name="finish")

    # The function's return type is the same as the type of its input `n`.
    fibonacci_type = lambda types: types[0]
    # The graph is attached below, after the blocks that call this function
    # (recursively) have been built.
    fibonacci_func = instructions.Function(None, ["n"],
                                           "ans",
                                           fibonacci_type,
                                           name="fibonacci")
    # pylint: disable=bad-whitespace
    # Definition of fibonacci function
    enter_fib.assign_instructions([
        instructions.prim_op(["n"], "cond", lambda n: n > 1),  # cond = n > 1
        instructions.BranchOp("cond", recur, finish),  # if cond
    ])
    recur.assign_instructions([
        instructions.prim_op(["n"], "nm1", lambda n: n - 1),  #   nm1 = n - 1
        instructions.FunctionCallOp(fibonacci_func, ["nm1"],
                                    "fibm1"),  #   fibm1 = fibonacci(nm1)
        instructions.prim_op(["n"], "nm2", lambda n: n - 2),  #   nm2 = n - 2
        instructions.FunctionCallOp(fibonacci_func, ["nm2"],
                                    "fibm2"),  #   fibm2 = fibonacci(nm2)
        instructions.prim_op(["fibm1", "fibm2"], "ans",
                             lambda x, y: x + y),  #   ans = fibm1 + fibm2
        instructions.halt_op(),  #   return ans
    ])
    finish.assign_instructions([  # else:
        instructions.prim_op([], "ans", lambda: 1),  #   ans = 1
        instructions.halt_op(),  #   return ans
    ])
    fibonacci_blocks = [enter_fib, recur, finish]
    fibonacci_func.graph = instructions.ControlFlowGraph(fibonacci_blocks)

    # Program entry point: ans = fibonacci(n1), then halt.
    fibonacci_main_blocks = [
        instructions.Block([
            instructions.FunctionCallOp(fibonacci_func, ["n1"], "ans"),
        ],
                           instructions.halt_op(),
                           name="main_entry"),
    ]

    # pylint: disable=bad-whitespace
    fibonacci_vars = {
        "n": instructions.single_type(dtype, ()),
        "n1": instructions.single_type(dtype, ()),
        # `np.bool` was a deprecated alias that NumPy 1.24 removed; `np.bool_`
        # is the NumPy boolean scalar type on every NumPy version.
        "cond": instructions.single_type(np.bool_, ()),
        "nm1": instructions.single_type(dtype, ()),
        "fibm1": instructions.single_type(dtype, ()),
        "nm2": instructions.single_type(dtype, ()),
        "fibm2": instructions.single_type(dtype, ()),
        "ans": instructions.single_type(dtype, ()),
    }
    if not include_types:
        _strip_types(fibonacci_vars)

    return instructions.Program(
        instructions.ControlFlowGraph(fibonacci_main_blocks), [fibonacci_func],
        fibonacci_vars, ["n1"], "ans")
コード例 #3
0
def is_even_function_calls(include_types=True, dtype=np.int64):
    """The is-even program, via "even-odd" recursion.

    Computes True if the input is even and False if the input is odd. It uses
    a pair of mutually recursive functions, is_even and is_odd, which return
    True and False respectively for inputs < 1. Each recursive step calls the
    other function on n - 1.

    Tests out mutual recursion, and explicit `PopOp`s of variables that are no
    longer needed.

    Args:
      include_types: If False, we omit types on the variables, requiring a type
        inference pass.
      dtype: The dtype to use for `n`-like internal state variables.

    Returns:
      program: Full-powered `instructions.Program` that computes is_even(n).
    """
    def pred_type(t):
        # Boolean result with the same shape as the first input's type.
        # `np.bool_` rather than the `np.bool` alias that NumPy 1.24 removed.
        return instructions.TensorType(np.bool_, t[0].shape)

    # Forward declaration of is_odd; its graph is attached further below.
    is_odd_func = instructions.Function(None, ["n"], "ans", pred_type)

    enter_is_even = instructions.Block()
    finish_is_even = instructions.Block()
    recur_is_even = instructions.Block()
    is_even_func = instructions.Function(None, ["n"], "ans", pred_type)
    # pylint: disable=bad-whitespace
    # Definition of is_even function
    enter_is_even.assign_instructions([
        instructions.prim_op(["n"], "cond", lambda n: n < 1),  # cond = n < 1
        instructions.BranchOp("cond", finish_is_even,
                              recur_is_even),  # if cond
    ])
    finish_is_even.assign_instructions([
        instructions.PopOp(["n", "cond"]),  #   done with n, cond
        instructions.prim_op([], "ans", lambda: True),  #   ans = True
        instructions.halt_op(),  #   return ans
    ])
    recur_is_even.assign_instructions([  # else
        instructions.PopOp(["cond"]),  #   done with cond now
        instructions.prim_op(["n"], "nm1", lambda n: n - 1),  #   nm1 = n - 1
        instructions.PopOp(["n"]),  #   done with n
        instructions.FunctionCallOp(is_odd_func, ["nm1"],
                                    "ans"),  #   ans = is_odd(nm1)
        instructions.PopOp(["nm1"]),  #   done with nm1
        instructions.halt_op(),  #   return ans
    ])
    is_even_blocks = [enter_is_even, finish_is_even, recur_is_even]
    is_even_func.graph = instructions.ControlFlowGraph(is_even_blocks)

    enter_is_odd = instructions.Block()
    finish_is_odd = instructions.Block()
    recur_is_odd = instructions.Block()
    # pylint: disable=bad-whitespace
    # Definition of is_odd function
    enter_is_odd.assign_instructions([
        instructions.prim_op(["n"], "cond", lambda n: n < 1),  # cond = n < 1
        instructions.BranchOp("cond", finish_is_odd, recur_is_odd),  # if cond
    ])
    finish_is_odd.assign_instructions([
        instructions.PopOp(["n", "cond"]),  #   done with n, cond
        instructions.prim_op([], "ans", lambda: False),  #   ans = False
        instructions.halt_op(),  #   return ans
    ])
    recur_is_odd.assign_instructions([  # else
        instructions.PopOp(["cond"]),  #   done with cond now
        instructions.prim_op(["n"], "nm1", lambda n: n - 1),  #   nm1 = n - 1
        instructions.PopOp(["n"]),  #   done with n
        instructions.FunctionCallOp(is_even_func, ["nm1"],
                                    "ans"),  #   ans = is_even(nm1)
        instructions.PopOp(["nm1"]),  #   done with nm1
        instructions.halt_op(),  #   return ans
    ])
    is_odd_blocks = [enter_is_odd, finish_is_odd, recur_is_odd]
    is_odd_func.graph = instructions.ControlFlowGraph(is_odd_blocks)

    # Program entry point: ans = is_even(n1), then halt.
    is_even_main_blocks = [
        instructions.Block([
            instructions.FunctionCallOp(is_even_func, ["n1"], "ans"),
        ], instructions.halt_op()),
    ]
    # pylint: disable=bad-whitespace
    is_even_vars = {
        "n": instructions.single_type(dtype, ()),
        "n1": instructions.single_type(dtype, ()),
        "cond": instructions.single_type(np.bool_, ()),
        "nm1": instructions.single_type(dtype, ()),
        "ans": instructions.single_type(np.bool_, ()),
    }
    if not include_types:
        _strip_types(is_even_vars)

    return instructions.Program(
        instructions.ControlFlowGraph(is_even_main_blocks),
        [is_even_func, is_odd_func], is_even_vars, ["n1"], "ans")