def fuse_pop_push(program):
  """Rewrites pop-then-push pairs in `program` as updates of the stack top.

  Popping a variable and then pushing it again, with no read of it in
  between, is equivalent to overwriting the top of its stack.  Per the
  original documentation of this pass, the overwrite is cheaper for FULL
  variables because it only touches the cached top instead of gathering from
  and scattering to the backing stack Tensor.

  The rewrite is local to each basic block: other instructions between the pop
  and the push are fine as long as they do not touch the variable, but a pair
  is never fused across a block boundary.  Pops that cannot be fused are
  deferred to the latest point in their block at which they are still sound
  (immediately before the next read of the variable, before a second pop of
  it, or at the end of the block's instructions).

  Args:
    program: A lowered `Program`.  The `instructions` of its `Block`s are
      replaced in place.

  Returns:
    fused: The input `program`, whose blocks now hold the rewritten
      instruction lists.  Fused variables are handed to `inst.prim_op` as the
      set of pushes to skip.

  Raises:
    ValueError: If a `FunctionCallOp` is encountered (the program has not been
      lowered), or if an instruction is of an unrecognized kind.
  """
  graph = program.graph
  for block_index in range(graph.exit_index()):
    block = graph.block(block_index)
    rewritten = []
    # Variables whose pop has been seen but not yet emitted, in the order the
    # pops were encountered.  Only the keys are meaningful.
    deferred_pops = collections.OrderedDict()
    for op in block.instructions:
      if isinstance(op, inst.PrimOp):
        # A read must observe the stack as it is after the pop, so any
        # deferred pop of an input has to be emitted before the op.
        for name in inst.pattern_traverse(op.vars_in):
          if deferred_pops.pop(name, False):
            rewritten.append(inst.PopOp([name]))
        # A write to a variable with a deferred pop is the fusible case: drop
        # the pop and tell the op not to push.
        fused = set()
        for name in inst.pattern_traverse(op.vars_out):
          if deferred_pops.pop(name, False):
            fused.add(name)
        rewritten.append(
            inst.prim_op(op.vars_in, op.vars_out, op.function, fused))
      elif isinstance(op, inst.FunctionCallOp):
        raise ValueError("pop-push fusion not implemented for pre-lowering")
      elif isinstance(op, inst.PopOp):
        for name in inst.pattern_traverse(op.vars):
          if name not in deferred_pops:
            deferred_pops[name] = True
          else:
            # Second pop of the same variable: emit one pop now and keep the
            # other one deferred.
            rewritten.append(inst.PopOp([name]))
      else:
        raise ValueError("Unrecognized op in pop-push fusion", op)
    # Whatever is still deferred gets popped together at the end.
    if deferred_pops:
      rewritten.append(inst.PopOp(list(deferred_pops)))
    block.instructions = rewritten
  return program
def _optimize_1(graph, live_out, alloc):
  """Records the allocation requirements implied by one CFG.

  Walks every block of `graph` and reports, through the module's
  `_variable_*` helpers, each variable that is read, that is live across a
  function call, or that is live across a block boundary.

  Args:
    graph: `ControlFlowGraph` to traverse.
    live_out: Set of `str` variable names that are live out of this graph
      (i.e., returned by the function this graph represents).
    alloc: Dictionary of allocation strategy deductions made so far.  This is
      mutated; but no variable is moved to a cheaper strategy.
  """
  liveness_map = liveness.liveness_analysis(graph, set(live_out))
  num_blocks = graph.exit_index()
  if num_blocks > 0:
    _variable_crosses_block_boundary(inst.pc_var, alloc)
  for index in range(num_blocks):
    block = graph.block(index)
    block_liveness = liveness_map[block]
    per_instruction = zip(
        block.instructions, block_liveness.live_out_instructions)
    for op, live_after_op in per_instruction:
      for varname in inst.pattern_traverse(_vars_read_by(op)):
        _variable_is_read(varname, alloc)
      if not isinstance(op, inst.FunctionCallOp):
        continue
      written_by_call = set(inst.pattern_flatten(op.vars_out))
      for varname in live_after_op - written_by_call:
        # TODO(axch): Strictly, the conservative storage strategy is only
        # needed when the variable crosses a call to a function that writes
        # it (e.g., a recursive self-call).  Detecting that here would
        # require traversing the call graph.
        _variable_crosses_function_call_boundary(varname, alloc)
      _variable_crosses_function_call_boundary(inst.pc_var, alloc)
    terminator = block.terminator
    if isinstance(terminator, inst.BranchOp):
      # TODO(axch): Being read by a BranchOp really only implies
      # _variable_is_read.  The downstream VM, however, cannot pop a
      # condition variable that is not needed after the BranchOp, so for now
      # a register has to be allocated for it.
      _variable_crosses_block_boundary(terminator.cond_var, alloc)
    for varname in block_liveness.live_out_of_block:
      _variable_crosses_block_boundary(varname, alloc)
def _is_determined(type_):
  """Returns True iff no tensor leaf of `type_` is None or has a None dtype."""
  leaves = instructions.pattern_traverse(
      type_.tensors, leaf_type=instructions.TensorType)
  return all(leaf is not None and leaf.dtype is not None for leaf in leaves)
def pattern(self, item):
  """Makes `item` write to this object's pattern; does nothing for None.

  The last instruction recorded in the context is replaced by a copy of `item`
  whose `vars_out` is this pattern (with every leaf converted via `str`), and
  every variable in the pattern is then marked as defined.

  Args:
    item: The instruction to rebind (anything with a `replace(vars_out=...)`
      method), or `None` to leave the context untouched.
  """
  if item is None:
    return
  # _MagicPattern is meant to be a friend class of ProgramBuilder.
  # pylint: disable=protected-access
  context = self._context
  names = self._pattern
  rebound = item.replace(vars_out=inst.pattern_map(str, names))
  context._update_last_instruction(rebound)
  for var in inst.pattern_traverse(names):
    context._mark_defined(var)
def _run_at_batch_size_one(backend, function, args):
  """Calls `function(*args)`, requiring a leading dimension of 1 throughout.

  Every tensor type found in each argument, and in each item of the result,
  must either have an empty shape or a first dimension equal to 1.

  Args:
    backend: Backend passed through to `ab_type_inference.type_of_pattern`.
    function: Callable to invoke.
    args: Positional arguments for `function`.

  Returns:
    result: Whatever `function(*args)` returned.

  Raises:
    ValueError: If any input or result has a leading dimension other than 1.
  """
  input_template = (
      'Expecting input batch dimension of size 1; got {} of shape {}.')
  result_template = (
      'Expecting result batch dimension of size 1; got {} of shape {}.')

  def require_unit_batch(value, template):
    leaf_types = instructions.pattern_traverse(
        ab_type_inference.type_of_pattern(value, backend),
        leaf_type=instructions.TensorType)
    for leaf_type in leaf_types:
      shape = leaf_type.shape
      if len(shape) >= 1 and shape[0] != 1:
        raise ValueError(template.format(value, shape))

  for arg in args:
    require_unit_batch(arg, input_template)
  result = function(*args)
  for item in instructions.pattern_traverse(result):
    require_unit_batch(item, result_template)
  return result
def _add_incompatible_batch_dim(type_pat):
  """Prepends a batch dim that differs from every dim already in `type_pat`.

  The chosen size is at least 2 and strictly larger than any dimension of any
  `TensorType` in the pattern, so it cannot be confused with an existing one.

  Args:
    type_pat: A pattern whose leaves are `instructions.Type`s.

  Returns:
    A pattern of the same structure in which every `TensorType` has the new
    batch dimension prepended to its shape.
  """
  tensor_types = instructions.pattern_traverse(
      type_pat, leaf_type=instructions.TensorType)
  new_batch_dim = 2
  for tensor_type in tensor_types:
    candidate = max((0, ) + tensor_type.shape) + 1
    if candidate > new_batch_dim:
      new_batch_dim = candidate
  log_debug('using incompatible batch dim %d', new_batch_dim)

  def with_batch_dim(tensor_type):
    return instructions.TensorType(
        tensor_type.dtype, (new_batch_dim, ) + tensor_type.shape)

  def add_batch_dim_one_var(type_):
    return instructions.Type(
        instructions.pattern_map(
            with_batch_dim, type_.tensors,
            leaf_type=instructions.TensorType))

  return instructions.pattern_map(
      add_batch_dim_one_var, type_pat, leaf_type=instructions.Type)
def call(self, function, vars_in, vars_out=None):
  """Records a call to an auto-batched function.

  Example:
  ```
  ab = dsl.ProgramBuilder()

  # Define a function
  with ab.function(...) as func:
    ...
    # Call it (recursively)
    ab.var.thing = ab.call(func, ...)
    ...
  ```

  Args:
    function: The `instructions.Function` object representing the function to
      call.
    vars_in: Python strings giving the variables to pass in as inputs.
    vars_out: A pattern of Python strings, giving the auto-batched variable(s)
      to which to write the result of the call.  Defaults to the empty list.

  Raises:
    ValueError: If any entry of `vars_in` is not a defined variable.

  Returns:
    op: The `instructions.FunctionCallOp` that was appended to the last block.
      If one subsequently assigns this to a local, via
      `ProgramBuilder.var.foo = op`, that local gets added to the list of
      output variables.
  """
  # Reject undefined inputs before any builder state is touched.
  for name in vars_in:
    if name not in self._var_defs:
      raise ValueError('Referencing undefined variable {}.'.format(name))
  self._prepare_for_instruction()
  outputs = [] if vars_out is None else vars_out
  call_op = inst.FunctionCallOp(
      function, _str_list(vars_in), inst.pattern_map(str, outputs))
  self._blocks[-1].instructions.append(call_op)
  for name in inst.pattern_traverse(outputs):
    self._mark_defined(name)
  return call_op
def return_(self, vars_out):
  """Records the return instruction of the function being defined.

  Example:
  ```python
  ab = dsl.ProgramBuilder()

  with ab.function(...) as f:
    ...
    ab.var.result = ...
    ab.return_(ab.var.result)
  ```

  A `return_` command must occur at the top level of the function definition
  (not inside any `if_`s), and must be the last statement therein.  You can
  always achieve this by assigning to a dedicated variable for the answer
  where you would otherwise return (and massaging your control flow).

  Args:
    vars_out: Pattern of Python strings giving the auto-batched variables to
      return.

  Raises:
    ValueError: If invoked more than once in a function body, or if trying to
      return variables that have not been written to.
  """
  # This relies on the return_ call being at the top level and being the last
  # statement of the body.  A nested return_ may have its terminator
  # overwritten incorrectly; a return_ followed by more statements may get
  # extra instructions inserted before the return (because return_ does not
  # set up a Block to catch them).
  self._prepare_for_instruction()
  for name in inst.pattern_traverse(vars_out):
    if name not in self._var_defs:
      raise ValueError('Returning undefined variable {}.'.format(name))
  current_function = self._functions[-1]
  if current_function.vars_out:
    raise ValueError(
        'Function body must have exactly one return_ statement')
  current_function.vars_out = inst.pattern_map(str, vars_out)
  self._blocks[-1].terminator = inst.halt_op()
def testProgramProperties(self):
  """Pins the FULL-variable and stack-push counts of the lowered NUTS program."""
  def target(x):
    return x

  sampler = nuts.NoUTurnSampler(target, step_size=0, use_auto_batching=True)
  program = sampler.autobatch_context.program_lowered('evolve_trajectory')

  def full_var(var):
    return program.var_alloc[var] == inst.VariableAllocation.FULL

  # Check that the number of FULL variables doesn't accidentally grow.  This
  # is deliberately an equality, not an inequality: it reminds the maintainer
  # to lower the expected number whenever a new pass manages to reduce the
  # count.
  num_full_vars = sum(1 for var in program.var_alloc if full_var(var))
  self.assertEqual(10, num_full_vars)

  def full_pushes_in(block):
    count = 0
    for op in block.instructions:
      if not hasattr(op, 'vars_out'):
        continue
      for var in inst.pattern_traverse(op.vars_out):
        if not full_var(var):
          continue
        # A write counts as a push unless the op explicitly skips it.
        skipped = hasattr(op, 'skip_push_mask') and var in op.skip_push_mask
        if not skipped:
          count += 1
    # A PushGotoOp terminator counts as one push of the program counter.
    if isinstance(block.terminator, inst.PushGotoOp) and full_var(inst.pc_var):
      count += 1
    return count

  # Check that the number of stack pushes doesn't accidentally grow.
  num_full_stack_pushes = sum(
      full_pushes_in(program.graph.block(i))
      for i in range(program.graph.exit_index()))
  self.assertEqual(19, num_full_stack_pushes)
def _lower_function_calls_1(
    builder, graph, defined_in, live_out, function=True):
  """Lowers one function body, destructively.

  Mutates the given `ControlFlowGraphBuilder`, inserting `Block`s
  representing the new body.  Some of these may be the same as some
  `Block`s in the input `graph`, mutated; others may be newly
  allocated.

  Each `FunctionCallOp` is replaced by a push of the arguments onto the
  callee's formal parameters, a transfer of control to the callee's first
  block, and (unless the call is an optimizable tail call) a push of the
  callee's results onto `op.vars_out` followed by a pop of the callee's
  output variables.  `PrimOp`s and `PopOp`s are carried over as they are.
  After each instruction, `builder.maybe_add_pop` is given the current
  defined and live sets so that it can insert pops as needed.

  Args:
    builder: `ControlFlowGraphBuilder` constructing the answer.
    graph: The `ControlFlowGraph` to lower.
    defined_in: A Python list of `str`.  The set of variables that
      are defined on entry to this `graph`.
    live_out: A Python list of `str`.  The set of variables that are
      live on exit from this `graph`.
    function: Python `bool`.  If `True` (the default), assume this is
      a `Function` body and convert an "exit" transfer into
      `IndirectGotoOp`; otherwise leave it as (`Program`) "exit".

  Raises:
    ValueError: If an invalid instruction is encountered, if a live
      variable is undefined, if different paths into a `Block` cause
      different sets of variables to be defined, or if trying to lower
      function calls in a program that already has `IndirectGotoOp`
      instructions (they confuse the liveness analysis).
  """
  # Both analyses are computed up front, before any block is mutated below.
  liveness_map = liveness.liveness_analysis(graph, set(live_out))
  defined_map = _definedness_analysis(graph, defined_in, liveness_map)
  for i in range(graph.exit_index()):
    block = graph.block(i)
    old_instructions = block.instructions
    # Resetting block.instructions here because we will build up the
    # list of new ones in place (via the `builder`).
    block.instructions = []
    builder.append_block(block)
    builder.maybe_add_pop(
        defined_map[block].defined_into_block,
        liveness_map[block].live_into_block)
    # NOTE: the loop variable `live_out` shadows the parameter of the same
    # name from here on; the parameter is only needed for the liveness
    # analysis above.
    for op_i, (op, defined_out, live_out) in enumerate(zip(
        old_instructions,
        defined_map[block].defined_out_instructions,
        liveness_map[block].live_out_instructions)):
      if isinstance(op, inst.PrimOp):
        for name in inst.pattern_traverse(op.vars_in):
          if name in inst.pattern_flatten(op.vars_out):
            # Why not?  Because the stack discipline we are trying to
            # implement calls for popping variables as soon as they
            # become dead.  Now, if a PrimOp writes to the same
            # variable as it reads, the old version of that variable
            # dies.  Where to put the PopOp?  Before the PrimOp is no
            # good -- it still needs to be read.  After the PrimOp is
            # no good either -- it will pop the output, not the input.
            # Various solutions to this problem are possible, such as
            # adding a "drop the second-top element of this stack"
            # instruction, or orchestrating the pushes and pops
            # directly in the interpreter, but for now the simplest
            # thing is to just forbid this situation.
            msg = 'Cannot lower PrimOp that writes to its own input {}.'
            raise ValueError(msg.format(name))
        builder.append_instruction(op)
        builder.maybe_add_pop(defined_out, live_out)
      elif isinstance(op, inst.FunctionCallOp):
        names_pushed_here = inst.pattern_flatten(op.vars_out)
        for name in inst.pattern_traverse(op.vars_in):
          if name in names_pushed_here:
            # For the same reason as above.
            msg = 'Cannot lower FunctionCallOp that writes to its own input {}.'
            raise ValueError(msg.format(name))
        builder.append_instruction(
            # Some of the pushees (op.function.vars_in) may be in scope if this
            # is a self-call.  They may therefore overlap with op.vars_in.  If
            # so, those values will be copied and/or duplicated.  Insofar as
            # op.vars_in become dead, some of this will be undone by the
            # following pop.  This is wasteful but sound.
            inst.push_op(op.vars_in, op.function.vars_in))
        builder.maybe_add_pop(
            # Pop names defined on entry now, to save a stack frame in the
            # function call.  Can't pop names defined by this call
            # because they haven't been pushed yet.
            defined_out.difference(names_pushed_here), live_out)
        # Only a call that is the last instruction of its block is a
        # candidate for the tail-call treatment.
        if (op_i == len(old_instructions) - 1 and _optimizable_tail_call(op, builder.cur_block())):
          builder.end_block_with_tail_call(op.function.graph.block(0))
          # The check that the tail call is optimizable is equivalent to
          # checking that the push-pop pair below would do nothing.
        else:
          builder.split_block(op.function.graph.block(0))
          builder.append_instruction(
              # These extra levels of list protect me (I hope) from the
              # auto-unpacking in the implementation of push_op, in the case of
              # a function returning exactly one Tensor.
              inst.push_op([op.function.vars_out], [op.vars_out]))
          builder.append_instruction(
              inst.PopOp(inst.pattern_flatten(op.function.vars_out)))
          # The only way this would actually add a pop is if some name written
          # by this call was a dummy variable.
          builder.maybe_add_pop(frozenset(names_pushed_here), live_out)
      elif isinstance(op, (inst.PopOp)):
        # Presumably, lowering is applied before any `PopOp`s are present.  That
        # said, simply propagating them is sound.  (But see the `PopOp` case in
        # `liveness_analysis`.)
        builder.append_instruction(op)
      else:
        raise ValueError('Invalid instruction in block {}.'.format(op))
    # Only function bodies get their terminator adjusted; see the `function`
    # argument.
    if function:
      builder.maybe_adjust_terminator()
def _process_block(block, visited, inferred_types, backend):
  """Executes a pass of type inference on a single `Block`.

  Only `PrimOp` and `FunctionCallOp` instructions are handled; any other
  instruction is logged and otherwise ignored.  An instruction whose input
  types are not all determined yet is skipped on this pass.

  Args:
    block: The `Block` whose `instructions` are processed in order.
    visited: Passed through unchanged to `_process_graph` when stepping into
      a called function.
    inferred_types: Mapping from variable name to its currently inferred
      type.  Read here by indexing; results are handed to `_merge_vars`
      (presumably the mapping is updated there -- confirm against
      `_merge_vars`).
    backend: Backend used to run primitive ops on dummy inputs and to compute
      the types of their results.

  Raises:
    TypeError: If a called function's `type_inference` returns a leaf that is
      not a `TensorType`.
  """
  for op in block.instructions:
    log_debug('handle op {}'.format(op))
    if isinstance(op, instructions.PrimOp):
      # Nothing can be inferred until every input type is known.
      if not all(
          _is_determined(inferred_types[var]) for var in op.vars_in):
        continue
      types_in = [inferred_types[var] for var in op.vars_in]
      # Offer type hints for cases where we need to type non-Tensor literals.
      preferred_types_out = instructions.pattern_map(
          lambda var: inferred_types[var], op.vars_out)
      # Run the primitive on dummy inputs that carry an extra batch dimension
      # (see `_add_incompatible_batch_dim`).
      with _type_inferring():
        objs_out = backend.run_on_dummies(
            op.function, _add_incompatible_batch_dim(types_in))
      # Type the dummy outputs, then strip the batch dimension again.
      types_out = _strip_batch_dim(
          instructions.pattern_map2(lambda tp, val: type_of_pattern(
              val, backend, preferred_type=tp),
                                    preferred_types_out,
                                    objs_out,
                                    leaf_type=instructions.Type))
      _merge_vars(op.vars_out, types_out, inferred_types, backend,
                  log_message='update PrimOp vars_out')
    elif isinstance(op, instructions.FunctionCallOp):
      # Nothing can be inferred until every input type is known.
      if not all(
          _is_determined(inferred_types[var]) for var in op.vars_in):
        continue
      # First, bind op.vars_in to op.function.vars_in.
      types_in = [inferred_types[var].tensors for var in op.vars_in]
      _merge_vars(op.function.vars_in, types_in, inferred_types, backend,
                  log_message='init function vars_in')
      # Execute type inference.
      types_out = op.function.type_inference(types_in)
      # Validate that the declared output type has only TensorType leaves.
      for leaf in instructions.pattern_traverse(
          types_out, leaf_type=instructions.TensorType):
        if not isinstance(leaf, instructions.TensorType):
          msg = ('Expected function output type to be '
                 'a nested list or tuple of TensorType, found {}.'
                ).format(leaf)
          raise TypeError(msg)
      # To help with typing recursive base-case return literals, we seed
      # return_vars types before stepping into the function.
      _merge_vars(op.function.vars_out, types_out, inferred_types, backend,
                  log_message='update function vars_out')
      # Finally, update op.vars_out with the results of type inference.
      _merge_vars(op.vars_out, types_out, inferred_types, backend,
                  log_message='update FunctionCall vars_out')
      # Step into function.  Note: it will only be visited once, if recursive.
      _process_graph(op.function.graph, visited, inferred_types, backend)
def primop(self, f, vars_in=None, vars_out=None):
  """Records a primitive operation.

  Example:

  ```
  ab = dsl.ProgramBuilder()

  ab.var.five = ab.const(5)
  # Implicit output binding
  ab.var.ten = ab.primop(lambda five: five + five)
  # Explicit output binding
  ab.primop(lambda: (5, 10), vars_out=[ab.var.five, ab.var.ten])
  ```

  Args:
    f: A Python callable, the primitive operation to perform.  Can be
      an inline lambda expression in simple cases.  Must return a list or
      tuple of results, one for each intended output variable.
    vars_in: A list of Python strings, giving the auto-batched variables
      to pass into the callable when invoking it.  If absent, `primop`
      will try to infer it by inspecting the argument list of the callable
      and matching against variables bound in the local scope.
    vars_out: A pattern of Python strings, giving the auto-batched
      variable(s) to which to write the result of the callable.  Defaults to
      the empty list.

  Raises:
    ValueError: If the primop references undefined auto-batched variables, or
      if auto-detection of input variables fails: a positional argument of
      `f` does not name a bound local, or `f` takes varargs, kwargs, or
      keyword-only arguments.

  Returns:
    op: An `instructions.PrimOp` instance representing this operation.  If one
      subsequently assigns this to a local, via `ProgramBuilder.var.foo = op`,
      that local becomes the output pattern.
  """
  self._prepare_for_instruction()
  if vars_out is None:
    vars_out = []
  if vars_in is None:
    # Deduce the intended variable names from the argument list of the callee.
    # Expected use case: the callee is an inline lambda expression.
    #
    # `inspect.getargspec` was removed in Python 3.11, so prefer
    # `getfullargspec` and fall back only on interpreters that lack it.  The
    # first three fields (args, varargs, keywords) mean the same in both.
    get_spec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    spec = get_spec(f)
    args, varargs, keywords = spec[0], spec[1], spec[2]
    vars_in = []
    for arg in args:
      if arg in self._locals:
        vars_in.append(self._locals[arg])
      else:
        raise ValueError(
            'Auto-referencing unbound variable {}.'.format(arg))
    if varargs is not None:
      raise ValueError('Varargs are not supported for primops')
    if keywords is not None:
      raise ValueError('kwargs are not supported for primops')
    # `getargspec` used to reject keyword-only parameters with a ValueError;
    # keep rejecting them explicitly, since they could never be bound to an
    # auto-batched variable.
    if getattr(spec, 'kwonlyargs', None):
      raise ValueError(
          'Keyword-only arguments are not supported for primops')
  for var in vars_in:
    if var not in self._var_defs:
      raise ValueError(
          'Referencing undefined variable {}.'.format(var))
  prim = inst.prim_op(_str_list(vars_in), inst.pattern_map(str, vars_out), f)
  self._blocks[-1].instructions.append(prim)
  for var in inst.pattern_traverse(vars_out):
    self._mark_defined(var)
  return prim