Exemplo n.º 1
0
    def if_(self,
            condition,
            then_name=None,
            else_name=None,
            continue_name=None):
        """Records a conditional operation and opens its `true` branch.

        The `false` branch, if present, must be guarded by a call to
        `else_`.

        Example:
        ```python
        ab = dsl.ProgramBuilder()

        ab.var.true = ab.const(True)
        with ab.if_(ab.var.true):
          ...  # The body of the `with` statement gives the `true` branch
        with ab.else_():  # The else_ clause is optional
          ...
        ```

        Args:
          condition: Python string giving the boolean variable that holds
            the branch condition.
          then_name: Optional Python string naming the true branch when
            the program is printed.
          else_name: Optional Python string naming the false branch when
            the program is printed.
          continue_name: Optional Python string naming the continuation
            after the if when the program is printed.

        Yields:
          Nothing.

        Raises:
          ValueError: If trying to condition on a variable that has not
            been written to.
        """
        # TODO: confirm whether _prepare_for_instruction is needed here.
        self._prepare_for_instruction()
        if condition not in self._var_defs:
            msg = 'Undefined variable {} used as if condition.'.format(
                condition)
            raise ValueError(msg)
        # Block allocation order matters for printed names/numbering.
        true_block = self._fresh_block(name=then_name)
        false_block = self._fresh_block(name=else_name)
        continuation = self._fresh_block(name=continue_name)
        branch_op = inst.BranchOp(str(condition), true_block, false_block)
        self._append_block(true_block, prev_terminator=branch_op)
        yield
        # The enclosed code may have ended in a dangling if; close it.
        self._prepare_for_instruction()
        # FIXME: Always adding this goto risks polluting the output with
        # excess gotos.  They can probably be cleaned up during label
        # resolution.
        self._append_block(false_block,
                           prev_terminator=inst.GotoOp(continuation))
        self._pending_after_else_block = continuation
Exemplo n.º 2
0
 def _append_block(self, block, prev_terminator=None):
     """Appends `block` to the graph, terminating the current last block.

     Args:
       block: The new (still open) block.  Must have an instruction list
         but no terminator yet.
       prev_terminator: Terminator to install on the current last block.
         Defaults to a `GotoOp` that falls through into `block`.
     """
     if prev_terminator is None:
         prev_terminator = inst.GotoOp(block)
     last = self._blocks[-1]
     assert last.terminator is None, 'Internal invariant violation'
     assert block.instructions is not None, 'Internal invariant violation'
     assert block.terminator is None, 'Internal invariant violation'
     last.terminator = prev_terminator
     self._blocks.append(block)
Exemplo n.º 3
0
    def end_block_with_tail_call(self, target):
        """End the current block with a jump to the given block.

        The current block's terminator becomes a `GotoOp` to `target`.
        No new block is created (as it would be in `split_block`),
        because by assumption there are no additional instructions to
        return to.

        Args:
          target: The block to jump to.
        """
        current = self.cur_block()
        current.terminator = inst.GotoOp(target)
Exemplo n.º 4
0
  def end_block_with_tail_call(self, target):
    """End the current block with a jump to the given block.

    The current block's terminator becomes a `GotoOp` to `target`.  No
    new block is created (as it would be in `split_block`), because by
    assumption there are no additional instructions to return to.

    Args:
      target: The block to jump to.
    """
    self.cur_block().terminator = inst.GotoOp(target)
    # Workaround: append a halting filler block as insurance against
    # having only 2 switch arms, which triggers a bug in
    # tensorflow/compiler/jit/deadness_analysis.
    filler = inst.Block(instructions=[], terminator=inst.halt_op())
    self.append_block(filler)
Exemplo n.º 5
0
def pea_nuts_program(latent_shape, choose_depth, step_state):
    """Synthetic program usable for benchmarking VM performance.

    This program is intended to resemble the control flow and scaling
    parameters of the NUTS algorithm, without any of the complexity.
    Hence the name.

    Each batch member looks like:

      state = ... # shape latent_shape

      def recur(depth, state):
        if depth > 0:
          state1 = recur(depth - 1, state)
          state2 = state1 + 1
          state3 = recur(depth - 1, state2)
          ans = state3 + 1
        else:
          ans = step_state(state)  # To simulate NUTS, something heavy
        return ans

      while count > 0:
        count = count - 1
        depth = choose_depth(count)
        state = recur(depth, state)

    Args:
      latent_shape: Python `tuple` of `int` giving the event shape of the
        latent state.
      choose_depth: Python `Tensor -> Tensor` callable.  The input
        `Tensor` will have shape `[batch_size]` (i.e., scalar event
        shape), and give the iteration of the outer while loop the
        thread is in.  The `choose_depth` function must return a `Tensor`
        of shape `[batch_size]` giving the depth, for each thread,
        to which to call `recur` in this iteration.
      step_state: Python `Tensor -> Tensor` callable.  The input and
        output `Tensor`s will have shape `[batch_size] + latent_shape`.
        This function is expected to update the state, and represents
        the "real work" versus which the VM overhead is being measured.

    Returns:
      program: `instructions.Program` that runs the above benchmark.
    """
    entry = instructions.Block()
    top_body = instructions.Block()
    finish_body = instructions.Block()
    enter_recur = instructions.Block()
    recur_body_1 = instructions.Block()
    recur_body_2 = instructions.Block()
    recur_body_3 = instructions.Block()
    recur_base_case = instructions.Block()
    # pylint: disable=bad-whitespace
    entry.assign_instructions([
        instructions.PrimOp(["count"], "cond",
                            lambda count: count > 0),  # cond = count > 0
        instructions.BranchOp("cond", top_body,
                              instructions.halt()),  # if cond
    ])
    top_body.assign_instructions([
        instructions.PopOp(["cond"]),  #   done with cond now
        instructions.PrimOp(["count"], "ctm1",
                            lambda count: count - 1),  #   ctm1 = count - 1
        instructions.PopOp(["count"]),  #   done with count now
        instructions.push_op(["ctm1"], ["count"]),  #   count = ctm1
        instructions.PopOp(["ctm1"]),  #   done with ctm1
        instructions.PrimOp(["count"], "depth",
                            choose_depth),  #   depth = choose_depth(count)
        instructions.push_op(
            ["depth", "state"],
            ["depth", "state"]),  #   state = recur(depth, state)
        instructions.PopOp(["depth", "state"]),  #     done with depth, state
        instructions.PushGotoOp(finish_body, enter_recur),
    ])
    finish_body.assign_instructions([
        instructions.push_op(["ans"], ["state"]),  #     ...
        instructions.PopOp(["ans"]),  #     pop callee's "ans"
        instructions.GotoOp(entry),  # end of while body
    ])
    # Definition of recur begins here
    enter_recur.assign_instructions([
        instructions.PrimOp(["depth"], "cond1",
                            lambda depth: depth > 0),  # cond1 = depth > 0
        instructions.BranchOp("cond1", recur_body_1,
                              recur_base_case),  # if cond1
    ])
    recur_body_1.assign_instructions([
        instructions.PopOp(["cond1"]),  #   done with cond1 now
        instructions.PrimOp(["depth"], "dm1",
                            lambda depth: depth - 1),  #   dm1 = depth - 1
        instructions.PopOp(["depth"]),  #   done with depth
        instructions.push_op(
            ["dm1", "state"],
            ["depth", "state"]),  #   state1 = recur(dm1, state)
        instructions.PopOp(["state"]),  #     done with state
        instructions.PushGotoOp(recur_body_2, enter_recur),
    ])
    recur_body_2.assign_instructions([
        instructions.push_op(["ans"], ["state1"]),  #     ...
        instructions.PopOp(["ans"]),  #     pop callee's "ans"
        instructions.PrimOp(["state1"], "state2",
                            lambda state: state + 1),  #   state2 = state1 + 1
        instructions.PopOp(["state1"]),  #   done with state1
        instructions.push_op(
            ["dm1", "state2"],
            ["depth", "state"]),  #   state3 = recur(dm1, state2)
        instructions.PopOp(["dm1", "state2"]),  #     done with dm1, state2
        instructions.PushGotoOp(recur_body_3, enter_recur),
    ])
    recur_body_3.assign_instructions([
        instructions.push_op(["ans"], ["state3"]),  #     ...
        instructions.PopOp(["ans"]),  #     pop callee's "ans"
        instructions.PrimOp(["state3"], "ans",
                            lambda state: state + 1),  #   ans = state3 + 1
        instructions.PopOp(["state3"]),  #   done with state3
        instructions.IndirectGotoOp(),  #   return ans
    ])
    recur_base_case.assign_instructions([
        instructions.PopOp(["cond1", "depth"]),  #   done with cond1, depth
        instructions.PrimOp(["state"], "ans",
                            step_state),  #   ans = step_state(state)
        instructions.PopOp(["state"]),  #   done with state
        instructions.IndirectGotoOp(),  #   return ans
    ])

    pea_nuts_graph = instructions.ControlFlowGraph([
        entry,
        top_body,
        finish_body,
        enter_recur,
        recur_body_1,
        recur_body_2,
        recur_body_3,
        recur_base_case,
    ])

    # Variable type declarations.  `np.bool_` (not the removed alias
    # `np.bool`) is required on NumPy >= 1.24.
    # pylint: disable=bad-whitespace
    pea_nuts_vars = {
        "count": instructions.single_type(np.int64, ()),
        "cond": instructions.single_type(np.bool_, ()),
        "cond1": instructions.single_type(np.bool_, ()),
        "ctm1": instructions.single_type(np.int64, ()),
        "depth": instructions.single_type(np.int64, ()),
        "dm1": instructions.single_type(np.int64, ()),
        "state": instructions.single_type(np.float32, latent_shape),
        "state1": instructions.single_type(np.float32, latent_shape),
        "state2": instructions.single_type(np.float32, latent_shape),
        "state3": instructions.single_type(np.float32, latent_shape),
        "ans": instructions.single_type(np.float32, latent_shape),
    }

    return instructions.Program(pea_nuts_graph, [], pea_nuts_vars,
                                ["count", "state"], "state")