Example #1
def single_if_program():
    """Single if program: 'if (input > 1) ans = 2; else ans = 0; return ans;'.

  Returns:
    program: `instructions.Program` with a simple conditional.
  """
    entry = instructions.Block()
    then_ = instructions.Block()
    else_ = instructions.Block()
    entry.assign_instructions([
        instructions.PrimOp(["input"], "cond", lambda n: n > 1),
        instructions.BranchOp("cond", then_, else_),
    ])
    then_.assign_instructions([
        instructions.PrimOp([], "answer", lambda: 2),
        instructions.halt_op(),
    ])
    else_.assign_instructions([
        instructions.PrimOp([], "answer", lambda: 0),
        instructions.halt_op(),
    ])

    single_if_blocks = [entry, then_, else_]
    # pylint: disable=bad-whitespace
    single_if_vars = {
        "input": instructions.single_type(np.int64, ()),
        "cond": instructions.single_type(np.bool, ()),
        "answer": instructions.single_type(np.int64, ()),
    }

    return instructions.Program(
        instructions.ControlFlowGraph(single_if_blocks), [], single_if_vars,
        ["input"], "answer")
Example #2
def shape_sequence_program(shape_sequence):
    """Program that writes into `answer` zeros having a sequence of shapes.

  This enables us to test that the final inferred shape is the broadcast of all
  intermediate shapes.

  Args:
    shape_sequence: The sequence of intermediate shapes.

  Returns:
    program: `instructions.Program` which returns an arbitrary value.
  """
    block_ops = []

    def op(shape, ans):
        # The trailing comma makes this a 1-tuple, matching the single-output
        # pattern ['ans'] used below.
        return np.zeros(shape, dtype=np.array(ans).dtype),

    for shape in shape_sequence:
        # We use a partial instead of a lambda in order to capture a copy of shape.
        block_ops.append(
            instructions.prim_op(['ans'], ['ans'],
                                 functools.partial(op, shape)))
    shape_seq_block = instructions.Block(block_ops, instructions.halt_op())
    shape_seq_vars = {
        'ans': instructions.Type(None),
        instructions.pc_var: instructions.single_type(np.int64, ()),
    }
    return instructions.Program(
        instructions.ControlFlowGraph([shape_seq_block]), [], shape_seq_vars,
        ['ans'], ['ans'])
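As a sanity check on the docstring's claim, here is a small NumPy-only sketch (assuming NumPy 1.20+ for `np.broadcast_shapes`; the names `op`, `step`, and `out` are local to this snippet) of what one generated op does and how the intermediate shapes combine by broadcasting.

```python
import functools

import numpy as np


def op(shape, ans):
    # Same shape-writing op as above: zeros of the requested shape, in a 1-tuple.
    return np.zeros(shape, dtype=np.array(ans).dtype),

step = functools.partial(op, (2, 3))   # captures one shape from the sequence
out, = step(np.int64(0))               # unpack the 1-tuple output
assert out.shape == (2, 3) and out.dtype == np.int64

# The final inferred shape is the broadcast of all intermediate shapes.
assert np.broadcast_shapes((2, 3), (3,), (1, 1)) == (2, 3)
```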
Example #3
def synthetic_pattern_variable_program(include_types=True):
    """A program that tests product types.

  Args:
    include_types: If False, we omit types on the variables, requiring a type
        inference pass.

  Returns:
    program: `instructions.Program`.
  """
    block = instructions.Block([
        instructions.PrimOp(["inp"], "many", lambda x: (x + 1,
                                                        (x + 2, x + 3))),
        instructions.PrimOp(["many"], ["one", "two"], lambda x: x),
    ], instructions.halt_op())

    leaf = instructions.TensorType(np.int64, ())
    the_vars = {
        "inp": instructions.Type(leaf),
        "many": instructions.Type((leaf, (leaf, leaf))),
        "one": instructions.Type(leaf),
        "two": instructions.Type((leaf, leaf)),
    }

    if not include_types:
        _strip_types(the_vars)
    return instructions.Program(instructions.ControlFlowGraph([block]), [],
                                the_vars, ["inp"], "two")
Example #4
  def end_block_with_tail_call(self, target):
    """End the current block with a jump to the given block.

    The terminator of the current block becomes a `GotoOp` to the target.
    No new block is created (as it would be in `split_block`), because
    by assumption there are no additional instructions to return to.

    Args:
      target: The block to jump to.
    """
    self.cur_block().terminator = inst.GotoOp(target)
    # Insurance against having only 2 switch arms, which triggers a bug in
    # tensorflow/compiler/jit/deadness_analysis.
    self.append_block(inst.Block(instructions=[], terminator=inst.halt_op()))
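A minimal sketch, reusing only the `inst` primitives already referenced in this method (`Block`, `GotoOp`, `halt_op`), of the structure this leaves behind: the current block's terminator is a `GotoOp` to the target, followed by the empty insurance block. The variable names here are illustrative only.

```python
# Hypothetical picture of the result of end_block_with_tail_call(target):
target = inst.Block(instructions=[], terminator=inst.halt_op())
caller = inst.Block(instructions=[], terminator=inst.GotoOp(target))
insurance = inst.Block(instructions=[], terminator=inst.halt_op())
```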
Example #5
def constant_program():
    """Constant program: 'ans=1; ans=2; return ans;'.

  Returns:
    program: `instructions.Program` which returns a constant value.
  """
    constant_block = instructions.Block([
        instructions.PrimOp([], "answer", lambda: 1),
        instructions.PrimOp([], "answer", lambda: 2),
    ], instructions.halt_op())

    constant_vars = {
        "answer": instructions.single_type(np.int64, ()),
    }

    return instructions.Program(
        instructions.ControlFlowGraph([constant_block]), [], constant_vars,
        ["answer"], "answer")
Example #6
    def return_(self, vars_out):
        """Records a function return instruction.

    Example:
    ```python
    ab = dsl.ProgramBuilder()

    with ab.function(...) as f:
      ...
      ab.var.result = ...
      ab.return_(ab.var.result)
    ```

    A `return_` command must occur at the top level of the function definition
    (not inside any `if_`s), and must be the last statement therein.  You can
    always achieve this by assigning to a dedicated variable for the answer
    where you would otherwise return (and massaging your control flow).

    Args:
      vars_out: Pattern of Python strings giving the auto-batched variables to
        return.

    Raises:
      ValueError: If invoked more than once in a function body, or if trying to
        return variables that have not been written to.
    """
        # Assume the return_ call is at the top level, and the last statement in the
        # body.  If return_ is nested, the terminator may be overwritten
        # incorrectly.  If return_ is followed by something else, extra instructions
        # may get inserted before the return (because return_ doesn't set up a Block
        # to catch them).
        self._prepare_for_instruction()
        for var in inst.pattern_traverse(vars_out):
            if var not in self._var_defs:
                raise ValueError(
                    'Returning undefined variable {}.'.format(var))
        if self._functions[-1].vars_out:
            raise ValueError(
                'Function body must have exactly one return_ statement')
        self._functions[-1].vars_out = inst.pattern_map(str, vars_out)
        self._blocks[-1].terminator = inst.halt_op()
Example #7
def synthetic_pattern_program():
    """A program that tests pattern matching of `PrimOp` outputs.

  Returns:
    program: `instructions.Program`.
  """
    block = instructions.Block([
        instructions.PrimOp(
            [], ("one", ("two", "three")), lambda: (1, (2, 3))),
        instructions.PrimOp(
            [], (("four", "five"), "six"), lambda: ((4, 5), 6)),
    ], instructions.halt_op())

    the_vars = {
        "one": instructions.single_type(np.int64, ()),
        "two": instructions.single_type(np.int64, ()),
        "three": instructions.single_type(np.int64, ()),
        "four": instructions.single_type(np.int64, ()),
        "five": instructions.single_type(np.int64, ()),
        "six": instructions.single_type(np.int64, ()),
    }

    return instructions.Program(instructions.ControlFlowGraph([block]), [],
                                the_vars, [],
                                (("one", "three"), "four", ("five", "six")))
Example #8
def fibonacci_function_calls(include_types=True, dtype=np.int64):
    """The Fibonacci program again, but with `instructions.FunctionCallOp`.

  Computes fib(n): fib(0) = fib(1) = 1.

  Args:
    include_types: If False, we omit types on the variables, requiring a type
        inference pass.
    dtype: The dtype to use for `n`-like internal state variables.

  Returns:
    program: Full-powered `instructions.Program` that computes fib(n).
  """
    enter_fib = instructions.Block(name="enter_fib")
    recur = instructions.Block(name="recur")
    finish = instructions.Block(name="finish")

    fibonacci_type = lambda types: types[0]
    fibonacci_func = instructions.Function(None, ["n"],
                                           "ans",
                                           fibonacci_type,
                                           name="fibonacci")
    # pylint: disable=bad-whitespace
    # Definition of fibonacci function
    enter_fib.assign_instructions([
        instructions.PrimOp(["n"], "cond", lambda n: n > 1),  # cond = n > 1
        instructions.BranchOp("cond", recur, finish),  # if cond
    ])
    recur.assign_instructions([
        instructions.PrimOp(["n"], "nm1", lambda n: n - 1),  #   nm1 = n - 1
        instructions.FunctionCallOp(fibonacci_func, ["nm1"],
                                    "fibm1"),  #   fibm1 = fibonacci(nm1)
        instructions.PrimOp(["n"], "nm2", lambda n: n - 2),  #   nm2 = n - 2
        instructions.FunctionCallOp(fibonacci_func, ["nm2"],
                                    "fibm2"),  #   fibm2 = fibonacci(nm2)
        instructions.PrimOp(["fibm1", "fibm2"], "ans",
                            lambda x, y: x + y),  #   ans = fibm1 + fibm2
        instructions.halt_op(),  #   return ans
    ])
    finish.assign_instructions([  # else:
        instructions.PrimOp([], "ans", lambda: 1),  #   ans = 1
        instructions.halt_op(),  #   return ans
    ])
    fibonacci_blocks = [enter_fib, recur, finish]
    fibonacci_func.graph = instructions.ControlFlowGraph(fibonacci_blocks)

    fibonacci_main_blocks = [
        instructions.Block([
            instructions.FunctionCallOp(fibonacci_func, ["n1"], "ans"),
        ],
                           instructions.halt_op(),
                           name="main_entry"),
    ]

    # pylint: disable=bad-whitespace
    fibonacci_vars = {
        "n": instructions.single_type(dtype, ()),
        "n1": instructions.single_type(dtype, ()),
        "cond": instructions.single_type(np.bool, ()),
        "nm1": instructions.single_type(dtype, ()),
        "fibm1": instructions.single_type(dtype, ()),
        "nm2": instructions.single_type(dtype, ()),
        "fibm2": instructions.single_type(dtype, ()),
        "ans": instructions.single_type(dtype, ()),
    }
    if not include_types:
        _strip_types(fibonacci_vars)

    return instructions.Program(
        instructions.ControlFlowGraph(fibonacci_main_blocks), [fibonacci_func],
        fibonacci_vars, ["n1"], "ans")
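For reference, the recursion the three blocks encode, written as plain Python; the name `fib` is illustrative only.

```python
def fib(n):
    # enter_fib: cond = n > 1
    if n > 1:
        # recur: ans = fib(n - 1) + fib(n - 2)
        return fib(n - 1) + fib(n - 2)
    # finish: ans = 1
    return 1

assert [fib(i) for i in range(6)] == [1, 1, 2, 3, 5, 8]
```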
Example #9
def is_even_function_calls(include_types=True, dtype=np.int64):
    """The is-even program, via "even-odd" recursion.

  Computes True if the input is even, False if the input is odd, by a pair of
  mutually recursive functions is_even and is_odd, which return True and False
  respectively for <1-valued inputs.

  Tests out mutual recursion.

  Args:
    include_types: If False, we omit types on the variables, requiring a type
        inference pass.
    dtype: The dtype to use for `n`-like internal state variables.

  Returns:
    program: Full-powered `instructions.Program` that computes is_even(n).
  """
    def pred_type(t):
        return instructions.TensorType(np.bool_, t[0].shape)

    # Forward declaration of is_odd.
    is_odd_func = instructions.Function(None, ["n"], "ans", pred_type)

    enter_is_even = instructions.Block()
    finish_is_even = instructions.Block()
    recur_is_even = instructions.Block()
    is_even_func = instructions.Function(None, ["n"], "ans", pred_type)
    # pylint: disable=bad-whitespace
    # Definition of is_even function
    enter_is_even.assign_instructions([
        instructions.PrimOp(["n"], "cond", lambda n: n < 1),  # cond = n < 1
        instructions.BranchOp("cond", finish_is_even,
                              recur_is_even),  # if cond
    ])
    finish_is_even.assign_instructions([
        instructions.PopOp(["n", "cond"]),  #   done with n, cond
        instructions.PrimOp([], "ans", lambda: True),  #   ans = True
        instructions.halt_op(),  #   return ans
    ])
    recur_is_even.assign_instructions([  # else
        instructions.PopOp(["cond"]),  #   done with cond now
        instructions.PrimOp(["n"], "nm1", lambda n: n - 1),  #   nm1 = n - 1
        instructions.PopOp(["n"]),  #   done with n
        instructions.FunctionCallOp(is_odd_func, ["nm1"],
                                    "ans"),  #   ans = is_odd(nm1)
        instructions.PopOp(["nm1"]),  #   done with nm1
        instructions.halt_op(),  #   return ans
    ])
    is_even_blocks = [enter_is_even, finish_is_even, recur_is_even]
    is_even_func.graph = instructions.ControlFlowGraph(is_even_blocks)

    enter_is_odd = instructions.Block()
    finish_is_odd = instructions.Block()
    recur_is_odd = instructions.Block()
    # pylint: disable=bad-whitespace
    # Definition of is_odd function
    enter_is_odd.assign_instructions([
        instructions.PrimOp(["n"], "cond", lambda n: n < 1),  # cond = n < 1
        instructions.BranchOp("cond", finish_is_odd, recur_is_odd),  # if cond
    ])
    finish_is_odd.assign_instructions([
        instructions.PopOp(["n", "cond"]),  #   done with n, cond
        instructions.PrimOp([], "ans", lambda: False),  #   ans = False
        instructions.halt_op(),  #   return ans
    ])
    recur_is_odd.assign_instructions([  # else
        instructions.PopOp(["cond"]),  #   done with cond now
        instructions.PrimOp(["n"], "nm1", lambda n: n - 1),  #   nm1 = n - 1
        instructions.PopOp(["n"]),  #   done with n
        instructions.FunctionCallOp(is_even_func, ["nm1"],
                                    "ans"),  #   ans = is_even(nm1)
        instructions.PopOp(["nm1"]),  #   done with nm1
        instructions.halt_op(),  #   return ans
    ])
    is_odd_blocks = [enter_is_odd, finish_is_odd, recur_is_odd]
    is_odd_func.graph = instructions.ControlFlowGraph(is_odd_blocks)

    is_even_main_blocks = [
        instructions.Block([
            instructions.FunctionCallOp(is_even_func, ["n1"], "ans"),
        ], instructions.halt_op()),
    ]
    # pylint: disable=bad-whitespace
    is_even_vars = {
        "n": instructions.single_type(dtype, ()),
        "n1": instructions.single_type(dtype, ()),
        "cond": instructions.single_type(np.bool, ()),
        "nm1": instructions.single_type(dtype, ()),
        "ans": instructions.single_type(np.bool, ()),
    }
    if not include_types:
        _strip_types(is_even_vars)

    return instructions.Program(
        instructions.ControlFlowGraph(is_even_main_blocks),
        [is_even_func, is_odd_func], is_even_vars, ["n1"], "ans")
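The mutual recursion the two function graphs encode, written as plain Python; the names are illustrative only.

```python
def is_even(n):
    # finish_is_even: n < 1 -> True; recur_is_even: delegate to is_odd(n - 1).
    return True if n < 1 else is_odd(n - 1)


def is_odd(n):
    # finish_is_odd: n < 1 -> False; recur_is_odd: delegate to is_even(n - 1).
    return False if n < 1 else is_even(n - 1)


assert is_even(4) and not is_even(5) and is_odd(3)
```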