def fuse_pop_push(program):
    """Fuses pop+push sequences in the given `Program`.

  A stack pop followed by a stack push (with no intervening read) is equivalent
  to just updating the top of the stack.  The latter is more efficient for FULL
  variables, because it just updates the cache for the top, and avoids gathering
  from and scattering to the backing stack Tensor.

  This pass mutates the `ControlFlowGraph` of the input `Program` to convert
  pop+push sequences into updates.  The pass will work despite intervening
  instructions that interact with other Variables, but will not cross basic
  block boundaries.  As a side-effect, the pass moves non-optimized pops to the
  last place in their basic block where they are still sound.  This has no
  effect on the runtime behavior of the program.

  Args:
    program: A lowered `Program` whose pop+push sequences to fuse.  `Block`s in
      the program may be mutated.

  Returns:
    fused: A `Program` with statically redundant pop+push sequences eliminated
      in favor of `PrimOp`s with non-trivial `skip_push_mask` fields.

  Raises:
    ValueError: If the input `Program` has not been lowered (i.e., contains
      `FunctionCallOp`), or is ill-formed (e.g., contains invalid instructions).
  """
    for i in range(program.graph.exit_index()):
        block = program.graph.block(i)
        new_instructions = []
        waiting_for_pop = collections.OrderedDict([])
        for op in block.instructions:
            if isinstance(op, inst.PrimOp):
                skip_push = set()
                for var in inst.pattern_traverse(op.vars_in):
                    if var in waiting_for_pop:
                        new_instructions.append(inst.PopOp([var]))
                        del waiting_for_pop[var]
                for var in inst.pattern_traverse(op.vars_out):
                    if var in waiting_for_pop:
                        skip_push.add(var)
                        del waiting_for_pop[var]
                new_instructions.append(
                    inst.prim_op(op.vars_in, op.vars_out, op.function,
                                 skip_push))
            elif isinstance(op, inst.FunctionCallOp):
                raise ValueError(
                    "pop-push fusion not implemented for pre-lowering")
            elif isinstance(op, inst.PopOp):
                for var in inst.pattern_traverse(op.vars):
                    if var in waiting_for_pop:
                        new_instructions.append(inst.PopOp([var]))
                    else:
                        waiting_for_pop[var] = True
            else:
                raise ValueError("Unrecognized op in pop-push fusion", op)
        if waiting_for_pop:
            new_instructions.append(inst.PopOp(list(waiting_for_pop.keys())))
        block.instructions = new_instructions
    return program
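# A standalone toy model of the same transformation (not the library's classes;
# the tuple encoding and names below are illustrative assumptions).  Each block
# entry is ('pop', var) or ('prim', ins, out); a deferred pop of `out` lets the
# prim overwrite the top of `out`'s stack instead of pushing a fresh value.
def toy_fuse(block):
    fused, pending_pops = [], set()
    for op in block:
        if op[0] == 'pop':
            _, var = op
            if var in pending_pops:
                fused.append(('pop', var))     # a second pop forces one pop to be emitted
            else:
                pending_pops.add(var)          # defer, hoping to fuse with a later write
        else:
            _, ins, out = op
            for var in ins:
                if var in pending_pops:        # the value is read, so its pop must happen
                    fused.append(('pop', var))
                    pending_pops.discard(var)
            if out in pending_pops:            # pop+push collapses into an in-place update
                pending_pops.discard(out)
                fused.append(('prim_update', ins, out))
            else:
                fused.append(op)
    fused.extend(('pop', var) for var in sorted(pending_pops))  # flush leftover pops
    return fused

# The pop of 'x' fuses with the write to 'x'; the read of 'y' forces its pop first.
assert toy_fuse([('pop', 'x'), ('pop', 'y'), ('prim', ['y'], 'x')]) == [
    ('pop', 'y'), ('prim_update', ['y'], 'x')]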
Example #2
def _is_determined(type_):
    for item in instructions.pattern_traverse(
            type_.tensors, leaf_type=instructions.TensorType):
        if item is None:
            return False
        if item.dtype is None:
            return False
    return True
Example #3
    def pattern(self, item):
        # _MagicPattern is meant to be a friend class of ProgramBuilder.
        # pylint: disable=protected-access
        if item is not None:
            self._context._update_last_instruction(
                item.replace(vars_out=inst.pattern_map(str, self._pattern)))
        for var in inst.pattern_traverse(self._pattern):
            self._context._mark_defined(var)
Example #4
def _run_at_batch_size_one(backend, function, args):
    """Executes the given function, checking in- and out-batch shape."""
    for arg in args:
        for type_ in instructions.pattern_traverse(
                ab_type_inference.type_of_pattern(arg, backend),
                leaf_type=instructions.TensorType):
            shape = type_.shape
            if len(shape) >= 1 and shape[0] != 1:
                msg = 'Expecting input batch dimension of size 1; got {} of shape {}.'
                raise ValueError(msg.format(arg, shape))
    result = function(*args)
    for item in instructions.pattern_traverse(result):
        for type_ in instructions.pattern_traverse(
                ab_type_inference.type_of_pattern(item, backend),
                leaf_type=instructions.TensorType):
            shape = type_.shape
            if len(shape) >= 1 and shape[0] != 1:
                msg = 'Expecting result batch dimension of size 1; got {} of shape {}.'
                raise ValueError(msg.format(item, shape))
    return result
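# A simplified standalone version of the same check, using plain NumPy arrays as
# stand-ins for the backend's inferred TensorTypes (the function and helper names
# here are illustrative, not the library's API).
import numpy as np

def _check_batch_dim_one(values, what):
    for value in values:
        shape = np.shape(value)
        if len(shape) >= 1 and shape[0] != 1:
            raise ValueError(
                'Expecting {} batch dimension of size 1; got shape {}.'.format(
                    what, shape))

def run_at_batch_size_one(function, *args):
    _check_batch_dim_one(args, 'input')
    result = function(*args)
    _check_batch_dim_one([result], 'result')
    return result

run_at_batch_size_one(lambda x: x + 1.0, np.zeros([1, 3]))    # passes
# run_at_batch_size_one(lambda x: x, np.zeros([4, 3]))        # would raise ValueError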
Example #5
def _add_incompatible_batch_dim(type_pat):
  """Adds a batch dim incompatible with all other known dims."""
  new_batch_dim = 2
  for tp in instructions.pattern_traverse(
      type_pat, leaf_type=instructions.TensorType):
    new_batch_dim = max(new_batch_dim, max((0,) + tp.shape) + 1)
  log_debug('using incompatible batch dim %d', new_batch_dim)
  def add_batch_dim_one_var(type_):
    return instructions.Type(instructions.pattern_map(
        lambda t: instructions.TensorType(t.dtype, (new_batch_dim,) + t.shape),
        type_.tensors, leaf_type=instructions.TensorType))
  return instructions.pattern_map(
      add_batch_dim_one_var, type_pat, leaf_type=instructions.Type)
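# Worked example of the dimension choice above, on plain shape tuples standing in
# for TensorType.shape.  The chosen dim must differ from every dimension size
# already seen, so it is one more than the largest of them, and at least 2
# (presumably to avoid size-0 and broadcastable size-1 dimensions).
shapes = [(3,), (5, 2), ()]
new_batch_dim = 2
for shape in shapes:
    new_batch_dim = max(new_batch_dim, max((0,) + shape) + 1)
# (3,) gives 3 + 1 = 4; (5, 2) gives 5 + 1 = 6; the scalar shape () contributes only 1.
assert new_batch_dim == 6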
Example #6
    def call(self, function, vars_in, vars_out=None):
        """Registers a function call instruction.

    Example:
    ```
    ab = dsl.ProgramBuilder()

    # Define a function
    with ab.function(...) as func:
      ...
      # Call it (recursively)
      ab.var.thing = ab.call(func, ...)
      ...
    ```

    Args:
      function: The `instructions.Function` object representing the function to
        call.
      vars_in: Python strings giving the variables to pass in as inputs.
      vars_out: A pattern of Python strings, giving the auto-batched variable(s)
        to which to write the result of the call.  Defaults to the empty list.

    Raises:
      ValueError: If the call references undefined auto-batched variables.

    Returns:
      op: An `instructions.FunctionCallOp` representing the call.  If one
        subsequently assigns this to a local, via `ProgramBuilder.var.foo = op`,
        that local gets added to the list of output variables.
    """
        for var in vars_in:
            if var not in self._var_defs:
                raise ValueError(
                    'Referencing undefined variable {}.'.format(var))
        self._prepare_for_instruction()
        if vars_out is None:
            vars_out = []
        call = inst.FunctionCallOp(function, _str_list(vars_in),
                                   inst.pattern_map(str, vars_out))
        self._blocks[-1].instructions.append(call)
        for var in inst.pattern_traverse(vars_out):
            self._mark_defined(var)
        return call
Example #7
def _process_block(block, visited, inferred_types, backend):
  """Executes a pass of type inference on a single `Block`."""
  for op in block.instructions:
    log_debug('handle op {}'.format(op))
    if isinstance(op, instructions.PrimOp):
      if not all(_is_determined(inferred_types[var]) for var in op.vars_in):
        continue
      types_in = [inferred_types[var] for var in op.vars_in]
      # Offer type hints for cases where we need to type non-Tensor literals.
      preferred_types_out = instructions.pattern_map(
          lambda var: inferred_types[var], op.vars_out)
      with _type_inferring():
        objs_out = backend.run_on_dummies(
            op.function, _add_incompatible_batch_dim(types_in))
      types_out = _strip_batch_dim(instructions.pattern_map2(
          lambda tp, val: type_of_pattern(val, backend, preferred_type=tp),
          preferred_types_out, objs_out, leaf_type=instructions.Type))
      _merge_vars(op.vars_out, types_out, inferred_types, backend,
                  log_message='update PrimOp vars_out')
    elif isinstance(op, instructions.FunctionCallOp):
      if not all(_is_determined(inferred_types[var]) for var in op.vars_in):
        continue
      # First, bind op.vars_in to op.function.vars_in.
      types_in = [inferred_types[var].tensors for var in op.vars_in]
      _merge_vars(op.function.vars_in, types_in, inferred_types, backend,
                  log_message='init function vars_in')
      # Execute type inference.
      types_out = op.function.type_inference(types_in)
      for leaf in instructions.pattern_traverse(
          types_out, leaf_type=instructions.TensorType):
        if not isinstance(leaf, instructions.TensorType):
          msg = ('Expected function output type to be '
                 'a nested list or tuple of TensorType, found {}.').format(leaf)
          raise TypeError(msg)
      # To help with typing recursive base-case return literals, we seed
      # return_vars types before stepping into the function.
      _merge_vars(op.function.vars_out, types_out, inferred_types, backend,
                  log_message='update function vars_out')
      # Finally, update op.vars_out with the results of type inference.
      _merge_vars(op.vars_out, types_out, inferred_types, backend,
                  log_message='update FunctionCall vars_out')
      # Step into function. Note: it will only be visited once, if recursive.
      _process_graph(op.function.graph, visited, inferred_types, backend)
Example #8
def _optimize_1(graph, live_out, alloc):
    """Optimize the variable allocation strategy for one CFG.

  Args:
    graph: `ControlFlowGraph` to traverse.
    live_out: Set of `str` variable names that are live out of this graph (i.e.,
      returned by the function this graph represents).
    alloc: Dictionary of allocation strategy deductions made so far.
      This is mutated; but no variable is moved to a cheaper strategy.
  """
    liveness_map = liveness.liveness_analysis(graph, set(live_out))
    if graph.exit_index() > 0:
        _variable_crosses_block_boundary(inst.pc_var, alloc)
    for i in range(graph.exit_index()):
        block = graph.block(i)
        for op, live_out in zip(block.instructions,
                                liveness_map[block].live_out_instructions):
            for varname in inst.pattern_traverse(_vars_read_by(op)):
                _variable_is_read(varname, alloc)
            if isinstance(op, inst.FunctionCallOp):
                callee_writes = _indirectly_writes(op.function)
                for varname in live_out - set(inst.pattern_flatten(
                        op.vars_out)):
                    # A variable only needs the conservative storage strategy if it
                    # crosses a call to some function that writes it (e.g., a recursive
                    # self-call).
                    if varname in callee_writes:
                        _variable_crosses_function_call_boundary(
                            varname, alloc)
                    else:
                        _variable_crosses_block_boundary(varname, alloc)
                # TODO(axch): Actually, the PC only needs a stack at this site if this
                # is not a tail call.
                _variable_crosses_function_call_boundary(inst.pc_var, alloc)
        if isinstance(block.terminator, inst.BranchOp):
            # TODO(axch): Actually, being read by BranchOp only implies
            # _variable_is_read.  However, the downstream VM doesn't know how to pop a
            # condition variable that is not needed after the BranchOp, so for now we
            # have to allocate a register for it.
            _variable_crosses_block_boundary(block.terminator.cond_var, alloc)
        for varname in liveness_map[block].live_out_of_block:
            _variable_crosses_block_boundary(varname, alloc)
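# A minimal sketch of the "no variable moves to a cheaper strategy" rule, assuming
# an ordering of allocation strategies from cheapest to most expensive (the real
# enum lives in the instructions module; `VariableAllocation.FULL` appears in the
# test example further below).  The helper names and comments are illustrative.
_STRATEGY_ORDER = ['null', 'temporary', 'register', 'full']

def _upgrade(varname, new_strategy, alloc):
    """Records `new_strategy` for `varname` unless a costlier one is already set."""
    current = alloc.get(varname, 'null')
    if _STRATEGY_ORDER.index(new_strategy) > _STRATEGY_ORDER.index(current):
        alloc[varname] = new_strategy

alloc = {}
_upgrade('x', 'register', alloc)   # e.g., x is live across a block boundary
_upgrade('x', 'temporary', alloc)  # a later, weaker deduction does not downgrade it
_upgrade('x', 'full', alloc)       # e.g., x is live across a recursive self-call
assert alloc == {'x': 'full'}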
Example #9
    def return_(self, vars_out):
        """Records a function return instruction.

    Example:
    ```python
    ab = dsl.ProgramBuilder()

    with ab.function(...) as f:
      ...
      ab.var.result = ...
      ab.return_(ab.var.result)
    ```

    A `return_` command must occur at the top level of the function definition
    (not inside any `if_`s), and must be the last statement therein.  You can
    always achieve this by assigning to a dedicated variable for the answer
    where you would otherwise return (and massaging your control flow).

    Args:
      vars_out: Pattern of Python strings giving the auto-batched variables to
        return.

    Raises:
      ValueError: If invoked more than once in a function body, or if trying to
        return variables that have not been written to.
    """
        # Assume the return_ call is at the top level, and the last statement in the
        # body.  If return_ is nested, the terminator may be overwritten
        # incorrectly.  If return_ is followed by something else, extra instructions
        # may get inserted before the return (because return_ doesn't set up a Block
        # to catch them).
        self._prepare_for_instruction()
        for var in inst.pattern_traverse(vars_out):
            if var not in self._var_defs:
                raise ValueError(
                    'Returning undefined variable {}.'.format(var))
        if self._functions[-1].vars_out:
            raise ValueError(
                'Function body must have exactly one return_ statement')
        self._functions[-1].vars_out = inst.pattern_map(str, vars_out)
        self._blocks[-1].terminator = inst.halt_op()
Example #10
    def testProgramProperties(self):
        def target(x):
            return x

        operator = tfp.experimental.mcmc.NoUTurnSampler(target,
                                                        step_size=0,
                                                        use_auto_batching=True)
        program = operator.autobatch_context.program_lowered(
            'evolve_trajectory')

        def full_var(var):
            return program.var_alloc[var] == inst.VariableAllocation.FULL

        # Check that the number of FULL variables doesn't accidentally grow.  This
        # is an equality rather than an inequality, to remind the maintainer to
        # reduce the expected number when they implement any pass that manages to
        # reduce the count.
        num_full_vars = 0
        for var in program.var_alloc:
            if full_var(var):
                num_full_vars += 1
        self.assertEqual(10, num_full_vars)

        # Check that the number of stack pushes doesn't accidentally grow.
        num_full_stack_pushes = 0
        for i in range(program.graph.exit_index()):
            block = program.graph.block(i)
            for op in block.instructions:
                if hasattr(op, 'vars_out'):
                    for var in inst.pattern_traverse(op.vars_out):
                        if full_var(var):
                            if (not hasattr(op, 'skip_push_mask')
                                    or var not in op.skip_push_mask):
                                num_full_stack_pushes += 1
            if isinstance(block.terminator, inst.PushGotoOp):
                if full_var(inst.pc_var):
                    num_full_stack_pushes += 1
        self.assertEqual(20, num_full_stack_pushes)
Example #11
    def primop(self, f, vars_in=None, vars_out=None):
        """Records a primitive operation.

    Example:

    ```
    ab = dsl.ProgramBuilder()

    ab.var.five = ab.const(5)
    # Implicit output binding
    ab.var.ten = ab.primop(lambda five: five + five)
    # Explicit output binding
    ab.primop(lambda: (5, 10), vars_out=[ab.var.five, ab.var.ten])
    ```

    Args:
      f: A Python callable, the primitive operation to perform.  Can be
        an inline lambda expression in simple cases.  Must return a list or
        tuple of results, one for each intended output variable.
      vars_in: A list of Python strings, giving the auto-batched variables
        to pass into the callable when invoking it.  If absent, `primop`
        will try to infer it by inspecting the argument list of the callable
        and matching against variables bound in the local scope.
      vars_out: A pattern of Python strings, giving the auto-batched variable(s)
        to which to write the result of the callable.  Defaults to the empty
        list.

    Raises:
      ValueError: If the definition is invalid, if the primop references
        undefined auto-batched variables, or if auto-detection of input
        variables fails.

    Returns:
      op: An `instructions.PrimOp` instance representing this operation.  If one
        subsequently assigns this to a local, via `ProgramBuilder.var.foo = op`,
        that local becomes the output pattern.
    """
        self._prepare_for_instruction()
        if vars_out is None:
            vars_out = []
        if vars_in is None:
            # Deduce the intended variable names from the argument list of the callee.
            # Expected use case: the callee is an inline lambda expression.
            args, varargs, keywords, _ = inspect.getargspec(f)
            vars_in = []
            for arg in args:
                if arg in self._locals:
                    vars_in.append(self._locals[arg])
                else:
                    raise ValueError(
                        'Auto-referencing unbound variable {}.'.format(arg))
            if varargs is not None:
                raise ValueError('Varargs are not supported for primops')
            if keywords is not None:
                raise ValueError('kwargs are not supported for primops')
        for var in vars_in:
            if var not in self._var_defs:
                raise ValueError(
                    'Referencing undefined variable {}.'.format(var))
        prim = inst.prim_op(_str_list(vars_in),
                            inst.pattern_map(str, vars_out), f)
        self._blocks[-1].instructions.append(prim)
        for var in inst.pattern_traverse(vars_out):
            self._mark_defined(var)
        return prim
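# Standalone sketch of the vars_in deduction described in the docstring above:
# read the callable's parameter names with `inspect` and resolve each against a
# dict of locally bound names.  (Illustrative only; the builder's real
# bookkeeping is richer than a plain dict, and the code above uses getargspec.)
import inspect

def deduce_vars_in(f, local_vars):
    spec = inspect.getfullargspec(f)
    if spec.varargs is not None:
        raise ValueError('Varargs are not supported for primops')
    if spec.varkw is not None:
        raise ValueError('kwargs are not supported for primops')
    vars_in = []
    for arg in spec.args:
        if arg not in local_vars:
            raise ValueError('Auto-referencing unbound variable {}.'.format(arg))
        vars_in.append(local_vars[arg])
    return vars_in

assert deduce_vars_in(lambda five: five + five, {'five': 'five'}) == ['five']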
Example #12
def _lower_function_calls_1(builder,
                            graph,
                            defined_in,
                            live_out,
                            function=True):
    """Lowers one function body, destructively.

  Mutates the given `ControlFlowGraphBuilder`, inserting `Block`s
  representing the new body.  Some of these may be the same as some
  `Block`s in the input `graph`, mutated; others may be newly
  allocated.

  Args:
    builder: `ControlFlowGraphBuilder` constructing the answer.
    graph: The `ControlFlowGraph` to lower.
    defined_in: A Python list of `str`.  The set of variables that
      are defined on entry to this `graph`.
    live_out: A Python list of `str`.  The set of variables that are
      live on exit from this `graph`.
    function: Python `bool`.  If `True` (the default), assume this is
      a `Function` body and convert an "exit" transfer into
      `IndirectGotoOp`; otherwise leave it as (`Program`) "exit".

  Raises:
    ValueError: If an invalid instruction is encountered, if a live
      variable is undefined, if different paths into a `Block` cause
      different sets of variables to be defined, or if trying to lower
      function calls in a program that already has `IndirectGotoOp`
      instructions (they confuse the liveness analysis).
  """
    liveness_map = liveness.liveness_analysis(graph, set(live_out))
    defined_map = _definedness_analysis(graph, defined_in, liveness_map)
    for i in range(graph.exit_index()):
        block = graph.block(i)
        old_instructions = block.instructions
        # Resetting block.instructions here because we will build up the
        # list of new ones in place (via the `builder`).
        block.instructions = []
        builder.append_block(block)
        builder.maybe_add_pop(defined_map[block].defined_into_block,
                              liveness_map[block].live_into_block)
        for op_i, (op, defined_out, live_out) in enumerate(
                zip(old_instructions,
                    defined_map[block].defined_out_instructions,
                    liveness_map[block].live_out_instructions)):
            if isinstance(op, inst.PrimOp):
                for name in inst.pattern_traverse(op.vars_in):
                    if name in inst.pattern_flatten(op.vars_out):
                        # Why not?  Because the stack discipline we are trying to
                        # implement calls for popping variables as soon as they
                        # become dead.  Now, if a PrimOp writes to the same
                        # variable as it reads, the old version of that variable
                        # dies.  Where to put the PopOp?  Before the PrimOp is no
                        # good -- it still needs to be read.  After the PrimOp is
                        # no good either -- it will pop the output, not the input.
                        # Various solutions to this problem are possible, such as
                        # adding a "drop the second-top element of this stack"
                        # instruction, or orchestrating the pushes and pops
                        # directly in the interpreter, but for now the simplest
                        # thing is to just forbid this situation.
                        # Fixing this is b/118884528.
                        msg = 'Cannot lower PrimOp that writes to its own input {}.'
                        raise ValueError(msg.format(name))
                builder.append_instruction(op)
                builder.maybe_add_pop(defined_out, live_out)
            elif isinstance(op, inst.FunctionCallOp):
                names_pushed_here = inst.pattern_flatten(op.vars_out)
                for name in inst.pattern_traverse(op.vars_in):
                    if name in names_pushed_here:
                        # For the same reason as above.
                        # Fixing this is b/118884528.
                        msg = 'Cannot lower FunctionCallOp that writes to its own input {}.'
                        raise ValueError(msg.format(name))
                # The variables that were defined on entry to this instruction (i.e.,
                # not pushed here) but are not live out don't need to remain on their
                # stacks when the callee is entered.
                defined_in = defined_out.difference(names_pushed_here)
                to_pop = defined_in.difference(live_out)
                for new_op in _function_entry_stack_ops(op, to_pop):
                    builder.append_instruction(new_op)
                if (op_i == len(old_instructions) - 1
                        and _optimizable_tail_call(op, builder.cur_block())):
                    builder.end_block_with_tail_call(
                        op.function.graph.block(0))
                    # The check that the tail call is optimizable is equivalent to
                    # checking that the push-pop pair below would do nothing.
                else:
                    builder.split_block(op.function.graph.block(0))
                    builder.append_instruction(
                        # These extra levels of list protect me (I hope) from the
                        # auto-unpacking in the implementation of push_op, in the case of
                        # a function returning exactly one Tensor.
                        inst.push_op([op.function.vars_out], [op.vars_out]))
                    builder.append_instruction(
                        inst.PopOp(inst.pattern_flatten(op.function.vars_out)))
                    # The only way this would actually add a pop is if some name written
                    # by this call was a dummy variable.
                    builder.maybe_add_pop(frozenset(names_pushed_here),
                                          live_out)
            elif isinstance(op, inst.PopOp):
                # Presumably, lowering is applied before any `PopOp`s are present.  That
                # said, simply propagating them is sound.  (But see the `PopOp` case in
                # `liveness_analysis`.)
                builder.append_instruction(op)
            else:
                raise ValueError('Invalid instruction in block {}.'.format(op))
        if function:
            builder.maybe_adjust_terminator()
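# Illustrative picture of what a single (non-tail) call site lowers to, pieced
# together from the code above; the exact instructions emitted by
# `_function_entry_stack_ops` and `split_block` are assumptions here, not
# verified library behavior.
#
#   # In the caller's block:
#   <push the actuals onto the callee's formal-parameter stacks, and pop caller
#    variables that are dead across the call>
#   PushGotoOp(return_block, callee_entry_block)      # ends the split block
#
#   # In the freshly split return block:
#   push_op([callee.vars_out], [op.vars_out])         # copy results to caller names
#   PopOp(pattern_flatten(callee.vars_out))           # drop the callee's return slots
#   <pop any outputs written here that are not live out>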