import collections

import numpy as np

# Assumed context: these snippets come from TensorFlow Probability's
# experimental `auto_batching` package; `inst` and `instructions` both name
# its `instructions` module, and `liveness` its liveness-analysis module.
# Helpers referenced but not shown here (`_strip_types`,
# `_definedness_analysis`, `_optimizable_tail_call`) live in the same
# source files.
from tensorflow_probability.python.experimental.auto_batching import instructions
from tensorflow_probability.python.experimental.auto_batching import instructions as inst
from tensorflow_probability.python.experimental.auto_batching import liveness


def fuse_pop_push(program):
    """Fuses pop+push sequences in the given `Program`.

  A stack pop followed by a stack push (with no intervening read) is equivalent
  to just updating the top of the stack.  The latter is more efficient for FULL
  variables, because it just updates the cache for the top, and avoids gathering
  from and scattering to the backing stack Tensor.

  This pass mutates the `ControlFlowGraph` of the input `Program` to convert
  pop+push sequences into updates.  The pass will work despite intervening
  instructions that interact with other Variables, but will not cross basic
  block boundaries.  As a side-effect, the pass moves non-optimized pops to the
  last place in their basic block where they are still sound.  This has no
  effect on the runtime behavior of the program.

  Args:
    program: A lowered `Program` whose pop+push sequences to fuse.  `Block`s in
      the program may be mutated.

  Returns:
    fused: A `Program` with statically redundant pop+push sequences eliminated
      in favor of `PrimOp`s with non-trivial `skip_push_mask` fields.

  Raises:
    ValueError: If the input `Program` has not been lowered (i.e., contains
      `FunctionCallOp`), or is ill-formed (e.g., contains invalid instructions).
  """
    for i in range(program.graph.exit_index()):
        block = program.graph.block(i)
        new_instructions = []
        waiting_for_pop = collections.OrderedDict([])
        for op in block.instructions:
            if isinstance(op, inst.PrimOp):
                skip_push = set()
                for var in inst.pattern_traverse(op.vars_in):
                    if var in waiting_for_pop:
                        new_instructions.append(inst.PopOp([var]))
                        del waiting_for_pop[var]
                for var in inst.pattern_traverse(op.vars_out):
                    if var in waiting_for_pop:
                        skip_push.add(var)
                        del waiting_for_pop[var]
                new_instructions.append(
                    inst.prim_op(op.vars_in, op.vars_out, op.function,
                                 skip_push))
            elif isinstance(op, inst.FunctionCallOp):
                raise ValueError(
                    "pop-push fusion not implemented for pre-lowering")
            elif isinstance(op, inst.PopOp):
                for var in inst.pattern_traverse(op.vars):
                    if var in waiting_for_pop:
                        new_instructions.append(inst.PopOp([var]))
                    else:
                        waiting_for_pop[var] = True
            else:
                raise ValueError("Unrecognized op in pop-push fusion", op)
        if waiting_for_pop:
            new_instructions.append(inst.PopOp(list(waiting_for_pop.keys())))
        block.instructions = new_instructions
    return program
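
# A minimal, self-contained sketch of the fusion idea above, using toy
# (kind, var) tuples instead of the real instruction classes -- the names
# here are hypothetical, for illustration only.  A pop is buffered; if a
# push of the same variable arrives before any read of it, the pair
# collapses into a single top-of-stack "update".
def _toy_fuse(ops):
    fused, pending = [], set()
    for kind, var in ops:
        if kind == "pop":
            pending.add(var)  # defer the pop; a later push may cancel it
        elif kind == "push" and var in pending:
            pending.discard(var)
            fused.append(("update", var))  # pop+push fused into one update
        else:
            if var in pending:  # a read: the deferred pop must land first
                pending.discard(var)
                fused.append(("pop", var))
            fused.append((kind, var))
    fused.extend(("pop", v) for v in sorted(pending))  # flush leftover pops
    return fused

# _toy_fuse([("pop", "x"), ("push", "x")]) == [("update", "x")]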
def _function_entry_stack_ops(op, to_pop):
    """Computes a set of stack operations for the entry to a FunctionCallOp.

  The function calling convention is
  - Push the arguments to the formal parameters
  - Pop any now-dead arguments so they're not on the stack during the call
  - Jump to the beginning of the function body

  This can be a little tricky for a self-call, because then the arguments and
  the formal parameters live in the same namespace and can collide.  This
  helper handles the cases it can, and raises an error when it can't.

  Args:
    op: FunctionCallOp instance giving the call to make stack operations for.
    to_pop: Set of names to make sure are popped before entering.

  Returns:
    ops: List of instruction objects that accomplish the goal.
  """
    push_from = []
    push_to = []
    caller_side_vars = inst.pattern_flatten(op.vars_in)
    callee_side_vars = inst.pattern_flatten(op.function.vars_in)
    for caller_side, callee_side in inst.pattern_zip(caller_side_vars,
                                                     callee_side_vars):
        if caller_side == callee_side:
            # This can happen if this is a self-call and we're just passing the
            # variable through to itself
            if caller_side in to_pop:
                # The top of the stack is already correct, and the callee will pop our
                # unneeded value off it for us: skip the push and the pop.
                to_pop = to_pop.difference([caller_side])
            else:
                # The top of the stack is correct, but we need to push it anyway because
                # the callee will (eventually) pop but we still need the current value
                # when the callee returns.
                push_from.append(caller_side)
                push_to.append(callee_side)
        elif callee_side in caller_side_vars:
            # If the graph of transfers turns out not to be a DAG, can't implement it
            # without temporary space; and don't want to bother computing a safe order
            # even if it does.
            # Fixing this is b/135275883.
            msg = ('Cannot lower FunctionCallOp that reuses its own input {}'
                   ' as a formal parameter.')
            raise ValueError(msg.format(caller_side))
        # Checking `elif caller_side in callee_side_vars` is redundant because
        # the callee_side check will trigger on that pair sooner or later.
        else:
            # Ordinary transfer: push now and then pop if needed.  The pop will not
            # interfere with the push because only caller-side variables can possibly
            # be popped.
            assert callee_side not in to_pop
            push_from.append(caller_side)
            push_to.append(callee_side)
    push_inst = inst.push_op(push_from, push_to)
    if to_pop:
        return [push_inst, inst.PopOp(list(to_pop))]
    else:
        return [push_inst]
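
# Illustration with hypothetical variable names: for an ordinary (non-self)
# call passing caller variables a, b to formal parameters x, y, where a is
# dead after the call site, the helper above produces roughly
#   [inst.push_op(["a", "b"], ["x", "y"]), inst.PopOp(["a"])]
# i.e. one push transferring the arguments to the formals, then one pop
# clearing the now-dead argument before the jump into the function body.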
def pea_nuts_program(latent_shape, choose_depth, step_state):
    """Synthetic program usable for benchmarking VM performance.

  This program is intended to resemble the control flow and scaling
  parameters of the NUTS algorithm, without any of the complexity.
  Hence the name.

  Each batch member looks like:

    state = ... # shape latent_shape

    def recur(depth, state):
      if depth > 0:
        state1 = recur(depth - 1, state)
        state2 = state1 + 1
        state3 = recur(depth - 1, state2)
        ans = state3 + 1
      else:
        ans = step_state(state)  # To simulate NUTS, something heavy
      return ans

    while count > 0:
      count = count - 1
      depth = choose_depth(count)
      state = recur(depth, state)

  Args:
    latent_shape: Python `tuple` of `int` giving the event shape of the
      latent state.
    choose_depth: Python `Tensor -> Tensor` callable.  The input
      `Tensor` will have shape `[batch_size]` (i.e., scalar event
      shape) and gives, for each thread, the current iteration of
      the outer while loop.  The `choose_depth` function must return
      a `Tensor` of shape `[batch_size]` giving the depth, for each
      thread, to which to call `recur` in this iteration.
    step_state: Python `Tensor -> Tensor` callable.  The input and
      output `Tensor`s will have shape `[batch_size] + latent_shape`.
      This function is expected to update the state, and represents
      the "real work" against which the VM overhead is being measured.

  Returns:
    program: `instructions.Program` that runs the above benchmark.
  """
    entry = instructions.Block()
    top_body = instructions.Block()
    finish_body = instructions.Block()
    enter_recur = instructions.Block()
    recur_body_1 = instructions.Block()
    recur_body_2 = instructions.Block()
    recur_body_3 = instructions.Block()
    recur_base_case = instructions.Block()
    # pylint: disable=bad-whitespace
    entry.assign_instructions([
        instructions.prim_op(["count"], "cond",
                             lambda count: count > 0),  # cond = count > 0
        instructions.BranchOp("cond", top_body,
                              instructions.halt()),  # if cond
    ])
    top_body.assign_instructions([
        instructions.PopOp(["cond"]),  #   done with cond now
        instructions.prim_op(["count"], "ctm1",
                             lambda count: count - 1),  #   ctm1 = count - 1
        instructions.PopOp(["count"]),  #   done with count now
        instructions.push_op(["ctm1"], ["count"]),  #   count = ctm1
        instructions.PopOp(["ctm1"]),  #   done with ctm1
        instructions.prim_op(["count"], "depth",
                             choose_depth),  #   depth = choose_depth(count)
        instructions.push_op(
            ["depth", "state"],
            ["depth", "state"]),  #   state = recur(depth, state)
        instructions.PopOp(["depth", "state"]),  #     done with depth, state
        instructions.PushGotoOp(finish_body, enter_recur),
    ])
    finish_body.assign_instructions([
        instructions.push_op(["ans"], ["state"]),  #     ...
        instructions.PopOp(["ans"]),  #     pop callee's "ans"
        instructions.GotoOp(entry),  # end of while body
    ])
    # Definition of recur begins here
    enter_recur.assign_instructions([
        instructions.prim_op(["depth"], "cond1",
                             lambda depth: depth > 0),  # cond1 = depth > 0
        instructions.BranchOp("cond1", recur_body_1,
                              recur_base_case),  # if cond1
    ])
    recur_body_1.assign_instructions([
        instructions.PopOp(["cond1"]),  #   done with cond1 now
        instructions.prim_op(["depth"], "dm1",
                             lambda depth: depth - 1),  #   dm1 = depth - 1
        instructions.PopOp(["depth"]),  #   done with depth
        instructions.push_op(
            ["dm1", "state"],
            ["depth", "state"]),  #   state1 = recur(dm1, state)
        instructions.PopOp(["state"]),  #     done with state
        instructions.PushGotoOp(recur_body_2, enter_recur),
    ])
    recur_body_2.assign_instructions([
        instructions.push_op(["ans"], ["state1"]),  #     ...
        instructions.PopOp(["ans"]),  #     pop callee's "ans"
        instructions.prim_op(["state1"], "state2",
                             lambda state: state + 1),  #   state2 = state1 + 1
        instructions.PopOp(["state1"]),  #   done with state1
        instructions.push_op(
            ["dm1", "state2"],
            ["depth", "state"]),  #   state3 = recur(dm1, state2)
        instructions.PopOp(["dm1", "state2"]),  #     done with dm1, state2
        instructions.PushGotoOp(recur_body_3, enter_recur),
    ])
    recur_body_3.assign_instructions([
        instructions.push_op(["ans"], ["state3"]),  #     ...
        instructions.PopOp(["ans"]),  #     pop callee's "ans"
        instructions.prim_op(["state3"], "ans",
                             lambda state: state + 1),  #   ans = state3 + 1
        instructions.PopOp(["state3"]),  #   done with state3
        instructions.IndirectGotoOp(),  #   return ans
    ])
    recur_base_case.assign_instructions([
        instructions.PopOp(["cond1", "depth"]),  #   done with cond1, depth
        instructions.prim_op(["state"], "ans",
                             step_state),  #   ans = step_state(state)
        instructions.PopOp(["state"]),  #   done with state
        instructions.IndirectGotoOp(),  #   return ans
    ])

    pea_nuts_graph = instructions.ControlFlowGraph([
        entry,
        top_body,
        finish_body,
        enter_recur,
        recur_body_1,
        recur_body_2,
        recur_body_3,
        recur_base_case,
    ])

    # pylint: disable=bad-whitespace
    pea_nuts_vars = {
        "count": instructions.single_type(np.int64, ()),
        "cond": instructions.single_type(np.bool, ()),
        "cond1": instructions.single_type(np.bool, ()),
        "ctm1": instructions.single_type(np.int64, ()),
        "depth": instructions.single_type(np.int64, ()),
        "dm1": instructions.single_type(np.int64, ()),
        "state": instructions.single_type(np.float32, latent_shape),
        "state1": instructions.single_type(np.float32, latent_shape),
        "state2": instructions.single_type(np.float32, latent_shape),
        "state3": instructions.single_type(np.float32, latent_shape),
        "ans": instructions.single_type(np.float32, latent_shape),
    }

    return instructions.Program(pea_nuts_graph, [], pea_nuts_vars,
                                ["count", "state"], "state")
def is_even_function_calls(include_types=True, dtype=np.int64):
    """The is-even program, via "even-odd" recursion.

  Computes True if the input is even, False if the input is odd, by a pair of
  mutually recursive functions is_even and is_odd, which return True and False
  respectively for <1-valued inputs.

  Tests out mutual recursion.

  Args:
    include_types: If False, we omit types on the variables, requiring a type
        inference pass.
    dtype: The dtype to use for `n`-like internal state variables.

  Returns:
    program: Full-powered `instructions.Program` that computes is_even(n).
  """
    def pred_type(t):
        return instructions.TensorType(np.bool_, t[0].shape)

    # Forward declaration of is_odd.
    is_odd_func = instructions.Function(None, ["n"], "ans", pred_type)

    enter_is_even = instructions.Block()
    finish_is_even = instructions.Block()
    recur_is_even = instructions.Block()
    is_even_func = instructions.Function(None, ["n"], "ans", pred_type)
    # pylint: disable=bad-whitespace
    # Definition of is_even function
    enter_is_even.assign_instructions([
        instructions.prim_op(["n"], "cond", lambda n: n < 1),  # cond = n < 1
        instructions.BranchOp("cond", finish_is_even,
                              recur_is_even),  # if cond
    ])
    finish_is_even.assign_instructions([
        instructions.PopOp(["n", "cond"]),  #   done with n, cond
        instructions.prim_op([], "ans", lambda: True),  #   ans = True
        instructions.halt_op(),  #   return ans
    ])
    recur_is_even.assign_instructions([  # else
        instructions.PopOp(["cond"]),  #   done with cond now
        instructions.prim_op(["n"], "nm1", lambda n: n - 1),  #   nm1 = n - 1
        instructions.PopOp(["n"]),  #   done with n
        instructions.FunctionCallOp(is_odd_func, ["nm1"],
                                    "ans"),  #   ans = is_odd(nm1)
        instructions.PopOp(["nm1"]),  #   done with nm1
        instructions.halt_op(),  #   return ans
    ])
    is_even_blocks = [enter_is_even, finish_is_even, recur_is_even]
    is_even_func.graph = instructions.ControlFlowGraph(is_even_blocks)

    enter_is_odd = instructions.Block()
    finish_is_odd = instructions.Block()
    recur_is_odd = instructions.Block()
    # pylint: disable=bad-whitespace
    # Definition of is_odd function
    enter_is_odd.assign_instructions([
        instructions.prim_op(["n"], "cond", lambda n: n < 1),  # cond = n < 1
        instructions.BranchOp("cond", finish_is_odd, recur_is_odd),  # if cond
    ])
    finish_is_odd.assign_instructions([
        instructions.PopOp(["n", "cond"]),  #   done with n, cond
        instructions.prim_op([], "ans", lambda: False),  #   ans = False
        instructions.halt_op(),  #   return ans
    ])
    recur_is_odd.assign_instructions([  # else
        instructions.PopOp(["cond"]),  #   done with cond now
        instructions.prim_op(["n"], "nm1", lambda n: n - 1),  #   nm1 = n - 1
        instructions.PopOp(["n"]),  #   done with n
        instructions.FunctionCallOp(is_even_func, ["nm1"],
                                    "ans"),  #   ans = is_even(nm1)
        instructions.PopOp(["nm1"]),  #   done with nm1
        instructions.halt_op(),  #   return ans
    ])
    is_odd_blocks = [enter_is_odd, finish_is_odd, recur_is_odd]
    is_odd_func.graph = instructions.ControlFlowGraph(is_odd_blocks)

    is_even_main_blocks = [
        instructions.Block([
            instructions.FunctionCallOp(is_even_func, ["n1"], "ans"),
        ], instructions.halt_op()),
    ]
    # pylint: disable=bad-whitespace
    is_even_vars = {
        "n": instructions.single_type(dtype, ()),
        "n1": instructions.single_type(dtype, ()),
        "cond": instructions.single_type(np.bool, ()),
        "nm1": instructions.single_type(dtype, ()),
        "ans": instructions.single_type(np.bool, ()),
    }
    if not include_types:
        _strip_types(is_even_vars)

    return instructions.Program(
        instructions.ControlFlowGraph(is_even_main_blocks),
        [is_even_func, is_odd_func], is_even_vars, ["n1"], "ans")
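
# For reference, the plain-Python mutual recursion the above program encodes
# (helper names here are illustrative, not from the library):
def _is_even_reference(n):
    return True if n < 1 else _is_odd_reference(n - 1)


def _is_odd_reference(n):
    return False if n < 1 else _is_even_reference(n - 1)

# _is_even_reference(4) == True; _is_even_reference(5) == False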
def fibonacci_program():
    """More complicated, fibonacci program: computes fib(n): fib(0) = fib(1) = 1.

  Returns:
    program: Full-powered `instructions.Program` that computes fib(n).
  """
    entry = instructions.Block(name="entry")
    enter_fib = instructions.Block(name="enter_fib")
    recur1 = instructions.Block(name="recur1")
    recur2 = instructions.Block(name="recur2")
    recur3 = instructions.Block(name="recur3")
    finish = instructions.Block(name="finish")
    # pylint: disable=bad-whitespace
    entry.assign_instructions([
        instructions.PushGotoOp(instructions.halt(), enter_fib),
    ])
    # Definition of fibonacci function starts here
    enter_fib.assign_instructions([
        instructions.prim_op(["n"], "cond", lambda n: n > 1),  # cond = n > 1
        instructions.BranchOp("cond", recur1, finish),  # if cond
    ])
    recur1.assign_instructions([
        instructions.PopOp(["cond"]),  #   done with cond now
        instructions.prim_op(["n"], "nm1", lambda n: n - 1),  #   nm1 = n - 1
        instructions.push_op(["nm1"], ["n"]),  #   fibm1 = fibonacci(nm1)
        instructions.PopOp(["nm1"]),  #     done with nm1
        instructions.PushGotoOp(recur2, enter_fib),
    ])
    recur2.assign_instructions([
        instructions.push_op(["ans"], ["fibm1"]),  #     ...
        instructions.PopOp(["ans"]),  #     pop callee's "ans"
        instructions.prim_op(["n"], "nm2", lambda n: n - 2),  #   nm2 = n - 2
        instructions.PopOp(["n"]),  #   done with n
        instructions.push_op(["nm2"], ["n"]),  #   fibm2 = fibonacci(nm2)
        instructions.PopOp(["nm2"]),  #     done with nm2
        instructions.PushGotoOp(recur3, enter_fib),
    ])
    recur3.assign_instructions([
        instructions.push_op(["ans"], ["fibm2"]),  #     ...
        instructions.PopOp(["ans"]),  #     pop callee's "ans"
        instructions.prim_op(["fibm1", "fibm2"], "ans",
                             lambda x, y: x + y),  #   ans = fibm1 + fibm2
        instructions.PopOp(["fibm1", "fibm2"]),  #   done with fibm1, fibm2
        instructions.IndirectGotoOp(),  #   return ans
    ])
    finish.assign_instructions([  # else:
        instructions.PopOp(["n", "cond"]),  #   done with n, cond
        instructions.prim_op([], "ans", lambda: 1),  #   ans = 1
        instructions.IndirectGotoOp(),  #   return ans
    ])

    fibonacci_blocks = [entry, enter_fib, recur1, recur2, recur3, finish]

    # pylint: disable=bad-whitespace
    fibonacci_vars = {
        "n": instructions.single_type(np.int64, ()),
        "cond": instructions.single_type(np.bool, ()),
        "nm1": instructions.single_type(np.int64, ()),
        "fibm1": instructions.single_type(np.int64, ()),
        "nm2": instructions.single_type(np.int64, ()),
        "fibm2": instructions.single_type(np.int64, ()),
        "ans": instructions.single_type(np.int64, ()),
    }

    return instructions.Program(
        instructions.ControlFlowGraph(fibonacci_blocks), [], fibonacci_vars,
        ["n"], "ans")
def _lower_function_calls_1(builder,
                            graph,
                            defined_in,
                            live_out,
                            function=True):
    """Lowers one function body, destructively.

  Mutates the given `ControlFlowGraphBuilder`, inserting `Block`s
  representing the new body.  Some of these may be the same as some
  `Block`s in the input `graph`, mutated; others may be newly
  allocated.

  Args:
    builder: `ControlFlowGraphBuilder` constructing the answer.
    graph: The `ControlFlowGraph` to lower.
    defined_in: A Python list of `str`.  The set of variables that
      are defined on entry to this `graph`.
    live_out: A Python list of `str`.  The set of variables that are
      live on exit from this `graph`.
    function: Python `bool`.  If `True` (the default), assume this is
      a `Function` body and convert an "exit" transfer into
      `IndirectGotoOp`; otherwise leave it as (`Program`) "exit".

  Raises:
    ValueError: If an invalid instruction is encountered, if a live
      variable is undefined, if different paths into a `Block` cause
      different sets of variables to be defined, or if trying to lower
      function calls in a program that already has `IndirectGotoOp`
      instructions (they confuse the liveness analysis).
  """
    liveness_map = liveness.liveness_analysis(graph, set(live_out))
    defined_map = _definedness_analysis(graph, defined_in, liveness_map)
    for i in range(graph.exit_index()):
        block = graph.block(i)
        old_instructions = block.instructions
        # Resetting block.instructions here because we will build up the
        # list of new ones in place (via the `builder`).
        block.instructions = []
        builder.append_block(block)
        builder.maybe_add_pop(defined_map[block].defined_into_block,
                              liveness_map[block].live_into_block)
        for op_i, (op, defined_out, live_out) in enumerate(
                zip(old_instructions,
                    defined_map[block].defined_out_instructions,
                    liveness_map[block].live_out_instructions)):
            if isinstance(op, inst.PrimOp):
                for name in inst.pattern_traverse(op.vars_in):
                    if name in inst.pattern_flatten(op.vars_out):
                        # Why not?  Because the stack discipline we are trying to
                        # implement calls for popping variables as soon as they
                        # become dead.  Now, if a PrimOp writes to the same
                        # variable as it reads, the old version of that variable
                        # dies.  Where to put the PopOp?  Before the PrimOp is no
                        # good -- it still needs to be read.  After the PrimOp is
                        # no good either -- it will pop the output, not the input.
                        # Various solutions to this problem are possible, such as
                        # adding a "drop the second-top element of this stack"
                        # instruction, or orchestrating the pushes and pops
                        # directly in the interpreter, but for now the simplest
                        # thing is to just forbid this situation.
                        # Fixing this is b/118884528.
                        msg = 'Cannot lower PrimOp that writes to its own input {}.'
                        raise ValueError(msg.format(name))
                builder.append_instruction(op)
                builder.maybe_add_pop(defined_out, live_out)
            elif isinstance(op, inst.FunctionCallOp):
                names_pushed_here = inst.pattern_flatten(op.vars_out)
                for name in inst.pattern_traverse(op.vars_in):
                    if name in names_pushed_here:
                        # For the same reason as above.
                        # Fixing this is b/118884528.
                        msg = 'Cannot lower FunctionCallOp that writes to its own input {}.'
                        raise ValueError(msg.format(name))
                # The variables that were defined on entry to this instruction (i.e.,
                # not pushed here) but are not live out don't need to remain on their
                # stacks when the callee is entered.
                defined_in = defined_out.difference(names_pushed_here)
                to_pop = defined_in.difference(live_out)
                for new_op in _function_entry_stack_ops(op, to_pop):
                    builder.append_instruction(new_op)
                if (op_i == len(old_instructions) - 1
                        and _optimizable_tail_call(op, builder.cur_block())):
                    builder.end_block_with_tail_call(
                        op.function.graph.block(0))
                    # The check that the tail call is optimizable is equivalent to
                    # checking that the push-pop pair below would do nothing.
                else:
                    builder.split_block(op.function.graph.block(0))
                    builder.append_instruction(
                        # These extra levels of list protect me (I hope) from the
                        # auto-unpacking in the implementation of push_op, in the case of
                        # a function returning exactly one Tensor.
                        inst.push_op([op.function.vars_out], [op.vars_out]))
                    builder.append_instruction(
                        inst.PopOp(inst.pattern_flatten(op.function.vars_out)))
                    # The only way this would actually add a pop is if some name written
                    # by this call was a dummy variable.
                    builder.maybe_add_pop(frozenset(names_pushed_here),
                                          live_out)
            elif isinstance(op, inst.PopOp):
                # Presumably, lowering is applied before any `PopOp`s are present.  That
                # said, simply propagating them is sound.  (But see the `PopOp` case in
                # `liveness_analysis`.)
                builder.append_instruction(op)
            else:
                raise ValueError('Invalid instruction {}.'.format(op))
        if function:
            builder.maybe_adjust_terminator()
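
# Illustration of the restriction enforced above, with a hypothetical op: a
# PrimOp like
#   inst.prim_op(["x"], "x", lambda x: x + 1)
# cannot be lowered, because the incoming `x` dies at this instruction and
# its PopOp has no sound home -- popping before the op would destroy a value
# it still needs to read, and popping after would pop the output, not the
# dead input.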
# Method of `ControlFlowGraphBuilder`, shown here out of its class context.
def maybe_add_pop(self, defined, live):
    """Appends a `PopOp` of the defined-but-dead variables, if any."""
    poppable = defined.difference(live)
    if poppable:
        self.append_instruction(inst.PopOp(list(poppable)))