Example #1
def lower_function_calls(program):
  """Lowers a `Program` that may have (recursive) FunctionCallOp instructions.

  Mutates the `ControlFlowGraph` of the input program in place.  After
  lowering, the result CFG

  - Has no `FunctionCallOp` instructions

  - Obeys a stack discipline

  What is the stack discipline?  Every function body becomes a CFG
  subset that:

  - Never transfers control in except to the first block
    (corresponding to being called), or to a block stored with
    `PushGotoOp` (corresponding to a subroutine returning)

  - Never transfers control out except with `IndirectGotoOp`
    (corresponding to returning), or with a `PushGotoOp`
    (corresponding to calling a subroutine)

  - Every path through the graph has the following effect on the
    variable stacks:

    - The formal parameters receive exactly one net pop

    - The return variables receive exactly one net push

    - All other variable stacks are left as they are

    - No data is read except the top frames of the formal parameter
      stacks

  Why mutate in place?  Because tying the knot in the result seemed
  too hard without an explicit indirection between `Block`s and
  references thereto in various `Op`s.  Specifically, when building a
  new CFG to replicate the structure of an existing one, it is
  necessary to allocate `Block`s to serve as the targets of all
  `BranchOp`, `GotoOp` (and `FunctionCallOp`) before building those
  `Op`s, and then correctly reuse those `Block`s when processing said
  targets.  With an explicit indirection, however, it would have been
  possible to reuse the same `Label`s, simply creating a new mapping
  from them to `Block`s.

  Note that this transformation assumes the CFGs being transformed do
  not use variable stacks internally; the stacks will be used only to
  implement the function call sequence once calls are lowered.  This
  assumption licenses inserting `PopOp`s to enforce a stack discipline
  for `FunctionCallOp`s.

  Args:
    program: A `Program` whose function calls to lower.  `Block`s in
      the program may be mutated.

  Returns:
    lowered: A `Program` that defines no `Function`s and does not use the
      `FunctionCallOp` instruction.  May share structure with the input
      `program`.

  Raises:
    ValueError: If an invalid instruction is encountered, if a live
      variable is undefined, if different paths into a `Block` cause
      different sets of variables to be defined, or if trying to lower
      function calls in a program that already has loops (within a
      `Function` body) or `IndirectGotoOp` instructions (they confuse
      the liveness analysis).
  """
  builder = ControlFlowGraphBuilder()
  _lower_function_calls_1(
      builder, program.graph, program.vars_in,
      inst.pattern_flatten(program.vars_out), function=False)
  for func in program.functions:
    _lower_function_calls_1(
        builder, func.graph, func.vars_in, inst.pattern_flatten(func.vars_out))
  return inst.Program(
      builder.control_flow_graph(), [],
      program.var_defs, program.vars_in, program.vars_out,
      program.var_alloc)
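
# A minimal, self-contained sketch (hypothetical; not part of the library)
# of the stack discipline the lowering enforces.  Variable stacks are plain
# Python lists; `_demo_lowered_call` mimics the lowered sequence for
# `y = f(x)` where `f` doubles its argument: push the actual onto the formal
# parameter's stack, run the body, pop the formal, push the result.
def _demo_lowered_call(stacks):
  # PushOp: copy the actual onto the formal parameter's stack.
  stacks['f_arg'].append(stacks['x'][-1])
  # PushGotoOp transfers in; the body reads only the top frame of the
  # formal parameter's stack.
  result = 2 * stacks['f_arg'][-1]
  # The formal parameters receive exactly one net pop ...
  stacks['f_arg'].pop()
  # ... and the return variables exactly one net push.
  stacks['f_ret'].append(result)
  # IndirectGotoOp returns; the caller copies the result out and pops it.
  stacks['y'].append(stacks['f_ret'][-1])
  stacks['f_ret'].pop()

stacks = {'x': [3], 'y': [], 'f_arg': [], 'f_ret': []}
_demo_lowered_call(stacks)
assert stacks == {'x': [3], 'y': [6], 'f_arg': [], 'f_ret': []}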
Example #2
def _lower_function_calls_1(
    builder, graph, defined_in, live_out, function=True):
  """Lowers one function body, destructively.

  Mutates the given `ControlFlowGraphBuilder`, inserting `Block`s
  representing the new body.  Some of these may be the same as some
  `Block`s in the input `graph`, mutated; others may be newly
  allocated.

  Args:
    builder: `ControlFlowGraphBuilder` constructing the answer.
    graph: The `ControlFlowGraph` to lower.
    defined_in: A Python list of `str`.  The set of variables that
      are defined on entry to this `graph`.
    live_out: A Python list of `str`.  The set of variables that are
      live on exit from this `graph`.
    function: Python `bool`.  If `True` (the default), assume this is
      a `Function` body and convert an "exit" transfer into
      `IndirectGotoOp`; otherwise leave it as (`Program`) "exit".

  Raises:
    ValueError: If an invalid instruction is encountered, if a live
      variable is undefined, if different paths into a `Block` cause
      different sets of variables to be defined, or if trying to lower
      function calls in a program that already has `IndirectGotoOp`
      instructions (they confuse the liveness analysis).
  """
  liveness_map = liveness.liveness_analysis(graph, set(live_out))
  defined_map = _definedness_analysis(graph, defined_in, liveness_map)
  for i in range(graph.exit_index()):
    block = graph.block(i)
    old_instructions = block.instructions
    # Resetting block.instructions here because we will build up the
    # list of new ones in place (via the `builder`).
    block.instructions = []
    builder.append_block(block)
    builder.maybe_add_pop(
        defined_map[block].defined_into_block,
        liveness_map[block].live_into_block)
    for op, defined_out, live_out in zip(
        old_instructions,
        defined_map[block].defined_out_instructions,
        liveness_map[block].live_out_instructions):
      if isinstance(op, inst.PrimOp):
        for name in inst.pattern_traverse(op.vars_in):
          if name in inst.pattern_flatten(op.vars_out):
            # Why not?  Because the stack discipline we are trying to
            # implement calls for popping variables as soon as they
            # become dead.  Now, if a PrimOp writes to the same
            # variable as it reads, the old version of that variable
            # dies.  Where to put the PopOp?  Before the PrimOp is no
            # good -- it still needs to be read.  After the PrimOp is
            # no good either -- it will pop the output, not the input.
            # Various solutions to this problem are possible, such as
            # adding a "drop the second-top element of this stack"
            # instruction, or orchestrating the pushes and pops
            # directly in the interpreter, but for now the simplest
            # thing is to just forbid this situation.
            msg = 'Cannot lower PrimOp that writes to its own input {}.'
            raise ValueError(msg.format(name))
        builder.append_instruction(op)
        builder.maybe_add_pop(defined_out, live_out)
      elif isinstance(op, inst.FunctionCallOp):
        names_pushed_here = inst.pattern_flatten(op.vars_out)
        for name in inst.pattern_traverse(op.vars_in):
          if name in names_pushed_here:
            # For the same reason as above.
            msg = 'Cannot lower FunctionCallOp that writes to its own input {}.'
            raise ValueError(msg.format(name))
        builder.append_instruction(
            # Some of the pushees (op.function.vars_in) may be in scope if this
            # is a self-call.  They may therefore overlap with op.vars_in.  If
            # so, those values will be copied and/or duplicated.  Insofar as
            # op.vars_in become dead, some of this will be undone by the
            # following pop.  This is wasteful but sound.
            inst.push_op(op.vars_in, op.function.vars_in))
        builder.maybe_add_pop(
            # Pop names defined on entry now, to save a stack frame in the
            # function call.  Can't pop names defined by this call
            # because they haven't been pushed yet.
            defined_out.difference(names_pushed_here), live_out)
        builder.split_block(op.function.graph.block(0))
        builder.append_instruction(
            # These extra levels of list protect me (I hope) from the
            # auto-unpacking in the implementation of push_op, in the case of a
            # function returning exactly one Tensor.
            inst.push_op([op.function.vars_out], [op.vars_out]))
        builder.append_instruction(
            inst.PopOp(inst.pattern_flatten(op.function.vars_out)))
        # The only way this would actually add a pop is if some name written by
        # this call was a dummy variable.
        builder.maybe_add_pop(frozenset(names_pushed_here), live_out)
      elif isinstance(op, inst.PopOp):
        # Presumably, lowering is applied before any `PopOp`s are present.  That
        # said, simply propagating them is sound.  (But see the `PopOp` case in
        # `liveness_analysis`.)
        builder.append_instruction(op)
      else:
        raise ValueError('Invalid instruction in block: {}.'.format(op))
    if function:
      builder.maybe_adjust_terminator()
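
# Hedged sketch of the `ControlFlowGraphBuilder.maybe_add_pop` helper used
# above.  The real method is not shown and its signature may differ, but the
# pop-placement rule follows from `_definedness_analysis`: pop exactly the
# variables that are currently defined but no longer live.
def _maybe_add_pop_sketch(builder, defined, live):
  dead = frozenset(defined).difference(live)
  if dead:
    builder.append_instruction(inst.PopOp(sorted(dead)))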
Example #3
def optimize(program):
    """Optimizes a `Program`'s variable allocation strategy.

    The variable allocation strategies determine how much memory the `Program`
    consumes, and how costly its memory access operations are (see
    `instructions.VariableAllocation`).  In general, a variable holding data with
    a longer or more complex lifetime will need a more expensive storage strategy.
    This analysis examines variables' liveness and opportunistically selects
    inexpensive sound allocation strategies.

    Specifically, the algorithm is to:
    - Run liveness analysis to determine the lifespan of each variable.
    - Assume optimistically that no variable needs to be stored at all
      (`instructions.VariableAllocation.NULL`).
    - Traverse the instructions and pattern-match conditions that require
      some storage:
      - If a variable is read by an instruction, it must be at least
        `instructions.VariableAllocation.TEMPORARY`.
      - If a variable is live out of some block (i.e., crosses a block boundary),
        it must be at least `instructions.VariableAllocation.REGISTER`.  This is
        because temporaries do not appear in the loop state in `execute`.
      - If a variable is alive across a call to an autobatched `Function`, it must
        be `instructions.VariableAllocation.FULL`, because that `Function` may
        push values to it that must not overwrite the value present at the call
        point.  (This can be improved by examining the call graph to see whether
        the callee really does push values to this variable, but that's future
        work.)

    Args:
      program: `Program` to optimize.

    Returns:
      program: A newly allocated `Program` with the same semantics but possibly
        different allocation strategies for some (or all) variables.  Each new
        strategy may be more efficient than the input `Program`'s allocation
        strategy for that variable (if the analysis can prove it safe), but will
        not be less efficient.
    """
    alloc = {
        var: inst.VariableAllocation.NULL
        for var in program.var_defs.keys()
    }
    # The program counter is always read
    _variable_is_read(inst.pc_var, alloc)
    _optimize_1(program.graph, inst.pattern_flatten(program.vars_out), alloc)
    for varname in program.vars_in:
        # Because there is a while_loop iteration between the inputs and the first
        # block.
        _variable_crosses_block_boundary(varname, alloc)
    for func in program.functions:
        _optimize_1(func.graph, inst.pattern_flatten(func.vars_out), alloc)
        for varname in func.vars_in:
            _variable_crosses_block_boundary(varname, alloc)
    null_vars = [
        k for k, v in six.iteritems(alloc) if v is inst.VariableAllocation.NULL
    ]
    if null_vars:
        logging.warning(
            'Found variables with NULL allocation. These are written '
            'but never read: %s', null_vars)
    return program.replace(var_alloc=alloc)
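
# The helpers `_variable_is_read` and `_variable_crosses_block_boundary`
# referenced above are not shown.  A plausible sketch, assuming
# `VariableAllocation` admits the ordering NULL < TEMPORARY < REGISTER < FULL
# and that a variable's strategy is only ever upgraded, never downgraded:
_ALLOCATION_ORDER_SKETCH = [
    inst.VariableAllocation.NULL,
    inst.VariableAllocation.TEMPORARY,
    inst.VariableAllocation.REGISTER,
    inst.VariableAllocation.FULL,
]


def _at_least_sketch(varname, alloc, strategy):
    """Upgrades `varname` to `strategy` if its current strategy is weaker."""
    order = _ALLOCATION_ORDER_SKETCH
    if order.index(alloc[varname]) < order.index(strategy):
        alloc[varname] = strategy


def _variable_is_read_sketch(varname, alloc):
    # A variable that is read needs at least temporary storage.
    _at_least_sketch(varname, alloc, inst.VariableAllocation.TEMPORARY)


def _variable_crosses_block_boundary_sketch(varname, alloc):
    # A variable that crosses a block boundary needs at least a register.
    _at_least_sketch(varname, alloc, inst.VariableAllocation.REGISTER)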
Example #4
def _interpret(program, mask, backend, block_code_cache, *inputs):
    """Worker function for `execute`; operates under a mask."""
    environment = inst.Environment.initialize(backend, program.var_alloc,
                                              program.var_defs, 0,
                                              backend.batch_size(mask))
    for var, inp in inst.pattern_zip(program.vars_in, inputs):
        environment[var] = environment.push(var, inp, mask)
    program_counter = 0  # Index of initial block
    queue = ExecutionQueue(backend)
    while program_counter != program.graph.exit_index():
        block = program.graph.block(program_counter)
        for split_idx, split in enumerate(_split_fn_calls(block.instructions)):
            if isinstance(split, inst.FunctionCallOp):
                op = split
                if (inst.pc_var in inst.pattern_flatten(op.vars_in)
                        or inst.pc_var in inst.pattern_flatten(op.vars_out)):
                    raise ValueError(
                        'FunctionCallOp reads or writes program counter: {}'.
                        format(op))
                inputs = [
                    inst.pattern_map(environment.read, var_pat)
                    for var_pat in op.vars_in
                ]
                # Option: Could gather+scatter at function boundaries.  Then the inner
                # interpreter would not need to accept the mask, but would need to
                # recompute the batch size and make a new mask of all ones.
                outputs = _invoke_fun(program, mask, backend, block_code_cache,
                                      op.function, inputs)
                new_vars = [(varname, environment.push(varname, output, mask))
                            for varname, output in inst.pattern_zip(
                                op.vars_out, outputs)]
                environment = inst.Environment(environment.env_dict,
                                               environment.backend,
                                               update=new_vars)
            else:  # This split is not a FunctionCallOp.
                block_code_key = (id(program.graph), program_counter,
                                  split_idx)
                if block_code_key not in block_code_cache:
                    logging.info('Fill block cache for block %s',
                                 block_code_key)
                    varnames = inst.extract_referenced_variables(split)
                    code = backend.wrap_straightline_callable(
                        functools.partial(_run_straightline, split,
                                          environment.backend))
                    block_code_cache[block_code_key] = (varnames, code)
                else:
                    logging.info('Use cached code for block %s',
                                 block_code_key)
                varnames, code = block_code_cache[block_code_key]
                # Only pass variables relevant to these ops.
                filtered_env = {
                    k: v
                    for k, v in six.iteritems(environment.env_dict)
                    if k in varnames
                }
                environment = inst.Environment(environment.env_dict,
                                               environment.backend,
                                               update=code(filtered_env, mask))
        op = block.terminator
        if isinstance(op, inst.BranchOp):
            if inst.pc_var == op.cond_var:
                raise ValueError('Branching on program counter: {}'.format(op))
            condition = environment.read(op.cond_var)
            true_index = program.graph.block_index(op.true_block)
            false_index = program.graph.block_index(op.false_block)
            queue.enqueue(true_index, mask & condition)
            queue.enqueue(false_index, mask & ~condition)
        elif isinstance(op, inst.GotoOp):
            next_index = program.graph.block_index(op.block)
            queue.enqueue(next_index, mask)
        else:
            raise TypeError('Unexpected op type: {}'.format(type(op)))
        program_counter, mask = queue.dequeue()
    # assert the queue is now empty
    return inst.pattern_map(environment.read, program.vars_out)
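
# `_split_fn_calls` is used above but not shown.  A minimal sketch, assuming
# it partitions a block's instruction list into maximal runs of straight-line
# ops (compiled and cached as a unit) and individual `FunctionCallOp`s
# (interpreted recursively):
def _split_fn_calls_sketch(instructions):
    """Yields runs of non-call ops as lists, and each `FunctionCallOp` alone."""
    run = []
    for op in instructions:
        if isinstance(op, inst.FunctionCallOp):
            if run:
                yield run
                run = []
            yield op
        else:
            run.append(op)
    if run:
        yield run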
Example #5
def _definedness_analysis(graph, defined_in, liveness_map):
    """Computes the defined and live variables.

  Specifically, for each op in each `Block`, computes the set of
  variables that are both defined coming out of it, and live coming
  out of the previous instruction.  The purpose of this analysis is to
  compute where `_lower_function_calls_1` should put `PopOp`s to
  enforce the stack discipline: The difference between the set
  computed here and the set of variables that are live out of each
  instruction is exactly the set for which a `PopOp` should be added.

  Why compute liveness and definedness jointly, rather than separately
  and then intersect them?  Because the purpose is to compute where to
  put the `PopOp`s, so that at the point at which any defined variable
  becomes dead, there is exactly one `PopOp`.  Placing such a `PopOp`
  will cause the variable in question to cease being defined, so this
  pass removes it from the defined set in anticipation thereof.

  Note that the semantics assumed by this analysis is that the control
  flow graph being analyzed does not use variable stacks internally,
  but they will only be used to implement the function sequence when
  function calls are lowered.  For this reason, a variable is treated
  as not being defined (regardless of what may be on its stack) until
  a write (as from `PrimOp` or `FunctionCallOp`) to it occurs (unless
  it comes in defined, in the `defined_in` argument).

  Args:
    graph: The `ControlFlowGraph` on which to perform definedness
      analysis.
    defined_in: A Python list of `str`.  The set of variables that
      are defined on entry to this `graph`.
    liveness_map: Python `dict` mapping each `Block` in `graph` to a
      `LivenessInfo` tuple, as produced by `liveness_analysis`.

  Returns:
    defined_map: Python `dict` mapping each `Block` in `graph` to a
      `DefinedInfo` tuple.  Each of these has two fields:
      `defined_into_block` gives the `frozenset` of `str` variable
      names defined into the block, and `defined_out_instructions`
      gives a list parallel to the `Block`s instructions list, of
      variables defined out of that instruction in the block, and live
      into it.

  Raises:
    ValueError: If an invalid instruction is encountered, if a live
      variable is undefined, or if different paths into a `Block`
      cause different sets of variables to be defined.
  """
    defined_map = {}
    defined = frozenset(defined_in)

    def record_vars_defined_on_entry(block, defined):
        if block not in defined_map:
            defined_map[block] = DefinedInfo(defined, [])
        elif defined_map[block].defined_into_block != defined:
            msg = ('Inconsistent defined variable set on entry into {}.\n'
                   'Had {}, getting {}.').format(
                       block, defined_map[block].defined_into_block, defined)
            raise ValueError(msg)

    record_vars_defined_on_entry(graph.block(0), defined)

    def check_live_variables_defined(defined, live):
        for name in live:
            if name not in defined:
                raise ValueError(
                    'Detected undefined live variable {}.'.format(name))

    for i in range(graph.exit_index()):
        block = graph.block(i)
        defined = defined_map[block].defined_into_block
        check_live_variables_defined(defined,
                                     liveness_map[block].live_into_block)
        defined = defined.intersection(liveness_map[block].live_into_block)
        # Loop invariant: At this point, `defined` is the set of variables
        # that are defined and live on entry into this op.
        for op, live_out in zip(block.instructions,
                                liveness_map[block].live_out_instructions):
            if isinstance(op, (inst.PrimOp, inst.FunctionCallOp)):
                defined = defined.union(inst.pattern_flatten(op.vars_out))
            elif isinstance(op, inst.PopOp):
                defined = defined.difference(op.vars)
            else:
                raise ValueError('Invalid instruction in block: {}.'.format(op))
            # At this point, `defined` is the set of variables that are
            # defined on exit from this op, and live on entry into this op.
            defined_map[block].defined_out_instructions.append(defined)
            check_live_variables_defined(defined, live_out)
            defined = defined.intersection(live_out)
            # At this point, `defined` is the set of variables that are
            # defined and live on exit from this op.
        op = block.terminator
        if isinstance(op, inst.BranchOp):
            record_vars_defined_on_entry(op.true_block, defined)
            record_vars_defined_on_entry(op.false_block, defined)
        elif isinstance(op, inst.GotoOp):
            record_vars_defined_on_entry(op.block, defined)
        elif isinstance(op, inst.PushGotoOp):
            record_vars_defined_on_entry(op.goto_block, defined)
        elif isinstance(op, inst.IndirectGotoOp):
            # Check that the return set is defined
            check_live_variables_defined(defined,
                                         liveness_map[None].live_into_block)
        else:
            raise ValueError('Invalid terminator instruction {}.'.format(op))
    return defined_map
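
# Worked micro-example (hypothetical variable names) of the pop-placement
# rule this analysis supports: after an op, the variables to pop are exactly
# those defined out of it but not live out of it.
_defined_out_demo = frozenset(['x', 'y', 'z'])
_live_out_demo = frozenset(['y'])
assert _defined_out_demo.difference(_live_out_demo) == frozenset(['x', 'z'])
# So lowering would place `PopOp(['x', 'z'])` immediately after that op.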
Example #6
def _run_block(graph, index, env_dict, backend):
  """Executes or stages one basic block.

  When staging, the `graph`, the `index`, and the `backend` are static.
  The `environment` contains the runtime values (e.g. `Tensor`s or
  `np.ndarray`s) being staged.

  Args:
    graph: The `instructions.ControlFlowGraph` of blocks that exist in
      the program being run.
    index: The int index of the specific block to execute or stage.
    env_dict: The current variable environment of the VM (for all logical
      threads). The program counter is included as a variable with a
      reserved name.
    backend: Object implementing required backend operations.

  Returns:
    new_environment: The new variable environment after performing the
      block on the relevant threads.

  Raises:
    ValueError: Invalid opcode or illegal operation (e.g. branch/read/write
      program counter).
  """
  environment = inst.Environment(env_dict, backend)
  program_counter = environment.read(inst.pc_var)
  block = graph.block(index)
  logging.debug('Staging block %d:\n%s', index, block)
  mask = backend.equal(program_counter, index)
  environment[inst.pc_var] = environment.pop(inst.pc_var, mask)
  def as_index(block):
    return backend.broadcast_to_shape_of(
        graph.block_index(block), program_counter)
  for op in block.instructions:
    if isinstance(op, inst.PrimOp):
      if (inst.pc_var in inst.pattern_flatten(op.vars_in) or
          inst.pc_var in inst.pattern_flatten(op.vars_out)):
        raise ValueError(
            'PrimOp reading or writing program counter: {}'.format(op))
      inputs = [inst.pattern_map(environment.read, var_pat)
                for var_pat in op.vars_in]
      with _vm_staging():
        outputs = op.function(*inputs)
      new_vars = [(varname, environment.push(varname, output, mask))
                  for varname, output in inst.pattern_zip(op.vars_out, outputs)]
    elif isinstance(op, inst.PopOp):
      new_vars = [(varname, environment.pop(varname, mask))
                  for varname in op.vars]
    else:
      raise ValueError('Invalid instruction in block: {}'.format(type(op)))
    environment = inst.Environment(
        environment.env_dict, environment.backend, update=new_vars)
  op = block.terminator
  if isinstance(op, inst.BranchOp):
    if inst.pc_var == op.cond_var:
      raise ValueError('Branching on program counter: {}'.format(op))
    condition = environment.read(op.cond_var)
    next_index = backend.where(
        condition, as_index(op.true_block), as_index(op.false_block))
    new_vars = []
  elif isinstance(op, inst.GotoOp):
    new_vars = []
    next_index = as_index(op.block)
  elif isinstance(op, inst.PushGotoOp):
    push_index = as_index(op.push_block)
    new_vars = [(inst.pc_var, environment.push(
        inst.pc_var, push_index, mask))]
    next_index = as_index(op.goto_block)
  elif isinstance(op, inst.IndirectGotoOp):
    next_index = environment.read(inst.pc_var)
    new_vars = [(inst.pc_var, environment.pop(inst.pc_var, mask))]
  else:
    raise TypeError('Unexpected op type: {}'.format(type(op)))
  environment = inst.Environment(
      environment.env_dict, environment.backend, update=new_vars)
  environment[inst.pc_var] = environment.push(inst.pc_var, next_index, mask)
  return environment.env_dict
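
# `_run_block` handles a single block; a driver must sweep the blocks until
# every thread has exited.  A hedged sketch of such a driver, assuming the
# backend exposes `while_loop`, `any`, and `not_equal` primitives
# (hypothetical names; only `equal` and `where` appear above):
def _run_program_sketch(graph, env_dict, backend):
  """Multiplexes all blocks repeatedly until every thread's pc is at exit."""
  def cond(env_dict):
    pc = inst.Environment(env_dict, backend).read(inst.pc_var)
    return backend.any(backend.not_equal(pc, graph.exit_index()))
  def body(env_dict):
    # Each sweep runs every block; blocks whose mask is all-False leave the
    # environment unchanged.
    for i in range(graph.exit_index()):
      env_dict = _run_block(graph, i, env_dict, backend)
    return env_dict
  return backend.while_loop(cond, body, env_dict)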