Example 1
def _run_straightline(ops, backend, env_dict, mask):
    """Imperatively run a list of straight-line ops, return updated `env_dict`."""
    env = inst.Environment(env_dict, backend)
    for op in ops:
        if isinstance(op, inst.PrimOp):
            if (inst.pc_var in inst.pattern_flatten(op.vars_in)
                    or inst.pc_var in inst.pattern_flatten(op.vars_out)):
                raise ValueError(
                    'PrimOp reads or writes program counter: {}'.format(op))
            inputs = [
                inst.pattern_map(env.read, var_pat) for var_pat in op.vars_in
            ]
            with _stackless_running():
                outputs = op.function(*inputs)
            new_vars = [
                (varname, env.push(varname, output, mask))
                for varname, output in inst.pattern_zip(op.vars_out, outputs)
            ]
        elif isinstance(op, inst.PopOp):
            new_vars = [(varname, env.pop(varname, mask))
                        for varname in op.vars]
        else:
            raise ValueError(
                'Invalid instruction in straightline segment: {}'.format(
                    type(op)))
        env = inst.Environment(env.env_dict, env.backend, update=new_vars)
    return env.env_dict
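
The masked `push` above is what lets one straight-line segment serve many logical threads at once: the op runs on the whole batch, but only threads selected by `mask` observe the result. A minimal standalone sketch of that idea (plain numpy, with a hypothetical `masked_write` helper rather than the `inst.Environment` API):

import numpy as np

def masked_write(old, new, mask):
    # Commit `new` only for batch members where `mask` is True; the boolean
    # mask is broadcast over any trailing event dimensions of `old`.
    return np.where(mask.reshape(mask.shape + (1,) * (old.ndim - 1)), new, old)

env = {'x': np.array([1., 2., 3.])}
mask = np.array([True, False, True])
# The op is applied to every batch member, but the unmasked middle thread
# keeps its previous value.
env['x'] = masked_write(env['x'], env['x'] + 10., mask)
print(env['x'])  # [11.  2. 13.]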
Example 2
def add_batch_dim_one_var(type_):
    return instructions.Type(
        instructions.pattern_map(
            lambda t: instructions.TensorType(
                t.dtype, (new_batch_dim,) + t.shape),
            type_.tensors,
            leaf_type=instructions.TensorType))
Example 3
  def run_on_dummies(self, primitive_callable, input_types):
    """Runs the given `primitive_callable` with dummy input.

    This is useful for examining the outputs for the purpose of type inference.

    Args:
      primitive_callable: A python callable.
      input_types: `list` of `instructions.Type` giving the type of each
        argument to the callable.  Note that the contained `TensorType`
        objects must match the dimensions with which the primitive is to be
        invoked at runtime, even though type inference conventionally does not
        store the batch dimension in the `TensorType`s.

    Returns:
      outputs: pattern of backend-specific objects whose types may be
        analyzed by the caller with `type_of`.
    """
    with tf.name_scope('VM.run_on_dummies'):
      # We cannot use a temporary graph in eager mode because user code may
      # close over eager tensors, causing `RuntimeError: Attempting to capture
      # an EagerTensor without building a function.`
      # In graph mode, capturing user Tensors has also been a problem, because
      # TF doesn't like the inputs of an op being in different graphs.
      # Status quo is unfortunate because it involves running the computation
      # in the primop to determine its shape behavior, instead of just invoking
      # shape inference.
      # There may be a solution involving FuncGraph; see b/118896442.
      def mk_placeholder(vt):
        return tf.ones(vt.shape, dtype=vt.dtype)
      phs = [
          instructions.pattern_map(
              mk_placeholder, vtype.tensors, leaf_type=instructions.TensorType)
          for vtype in input_types]
      return primitive_callable(*phs)
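
As a rough standalone illustration of the dummy-execution idea (with a hypothetical `infer_output_spec` helper, not this backend's API): feed the callable concrete placeholder-like tensors and read the dtype and shape off its output.

import tensorflow as tf

def infer_output_spec(fn, input_specs):
    # `input_specs` is a list of (shape, dtype) pairs for the dummy inputs.
    dummies = [tf.ones(shape, dtype=dtype) for shape, dtype in input_specs]
    result = fn(*dummies)
    return result.dtype, result.shape

dtype, shape = infer_output_spec(
    lambda a, b: tf.matmul(a, b),
    [((2, 3), tf.float32), ((3, 4), tf.float32)])
# dtype == tf.float32; shape == (2, 4)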
Example 4
    def run_on_dummies(self, primitive_callable, input_types):
        """Runs the given `primitive_callable` with dummy input.

    This is useful for examining the outputs for the purpose of type inference.

    Args:
      primitive_callable: A python callable.
      input_types: `list` of `instructions.Type` giving the type of each
        argument to the callable.  Note that the contained `TensorType`
        objects must match the dimensions with which the primitive is to be
        invoked at runtime, even though type inference conventionally does not
        store the batch dimension in the `TensorType`s.

    Returns:
      outputs: pattern of backend-specific objects whose types may be
        analyzed by the caller with `type_of`.
    """
        def at_tensor(vt):
            return np.zeros(vt.shape, dtype=vt.dtype)

        inputs = [
            instructions.pattern_map(at_tensor,
                                     type_.tensors,
                                     leaf_type=instructions.TensorType)
            for type_ in input_types
        ]
        return primitive_callable(*inputs)
Example 5
def pattern(self, item):
    # _MagicPattern is meant to be a friend class of ProgramBuilder.
    # pylint: disable=protected-access
    if item is not None:
        self._context._update_last_instruction(
            item.replace(vars_out=inst.pattern_map(str, self._pattern)))
    for var in inst.pattern_traverse(self._pattern):
        self._context._mark_defined(var)
Example 6
def _init_f(env_dict, *args):
  """A RegisterTensorFlowVariable-initializing wrapper of `_f`."""
  # We ensure RegisterTensorFlowVariable instances have a Tensor value when
  # using XLA and/or defun. Otherwise, we will trigger cache misses on the
  # tfe.defun or get issues around "Cannot convert object of type [dtype] to
  # a Tensor" (XLA). This corresponds with the optimization in
  # `create_variable` conditioned on Eager & VariableAllocation.REGISTER.
  env_dict = dict({k: instructions.pattern_map(
      _ensure_regvars_initialized, v, leaf_type=RegisterTensorFlowVariable)
                   for k, v in six.iteritems(env_dict)})
  return _f(env_dict, *args)
Example 7
def at_leaf(preferred_leaf_type, obj):
    """Pattern match at a leaf of the preferred_type pattern."""
    if preferred_leaf_type is None:
        return instructions.pattern_map(backend.type_of, obj)
    if isinstance(preferred_leaf_type, instructions.TensorType):
        return backend.type_of(obj, preferred_leaf_type.dtype)
    # Otherwise, preferred_leaf_type must be a (nested) list or tuple of
    # TensorType, while obj is not a list or a tuple (of anything).  In this
    # case, pattern_map2 should have raised an error, but we can defensively
    # raise an error here as well.
    msg = 'Type mismatch: Expected structured type {}, got object {}.'.format(
        preferred_leaf_type, obj)
    raise ValueError(msg)
Example 8
def _add_incompatible_batch_dim(type_pat):
  """Adds a batch dim incompatible with all other known dims."""
  new_batch_dim = 2
  for tp in instructions.pattern_traverse(
      type_pat, leaf_type=instructions.TensorType):
    new_batch_dim = max(new_batch_dim, max((0,) + tp.shape) + 1)
  log_debug('using incompatible batch dim %d', new_batch_dim)
  def add_batch_dim_one_var(type_):
    return instructions.Type(instructions.pattern_map(
        lambda t: instructions.TensorType(t.dtype, (new_batch_dim,) + t.shape),
        type_.tensors, leaf_type=instructions.TensorType))
  return instructions.pattern_map(
      add_batch_dim_one_var, type_pat, leaf_type=instructions.Type)
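
A quick worked version of the dimension choice above, with plain shape tuples standing in for the `TensorType` leaves: the fake batch dimension ends up one larger than any dimension already present, and never smaller than 2.

shapes = [(3,), (5, 2), ()]
new_batch_dim = 2
for shape in shapes:
    new_batch_dim = max(new_batch_dim, max((0,) + shape) + 1)
print(new_batch_dim)  # 6: larger than every dimension seen, so it cannot be
                      # mistaken for one of them during shape inference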
Example 9
    def call(self, function, vars_in, vars_out=None):
        """Registers a function call instruction.

    Example:
    ```
    ab = dsl.ProgramBuilder()

    # Define a function
    with ab.function(...) as func:
      ...
      # Call it (recursively)
      ab.var.thing = ab.call(func, ...)
      ...
    ```

    Args:
      function: The `instructions.Function` object representing the function to
        call.
      vars_in: Python strings giving the variables to pass in as inputs.
      vars_out: A pattern of Python strings, giving the auto-batched variable(s)
        to which to write the result of the call.  Defaults to the empty list.

    Raises:
      ValueError: If the call references undefined auto-batched variables.

    Returns:
      op: An `instructions.FunctionCallOp` representing the call.  If one
        subsequently assigns this to a local, via `ProgramBuilder.var.foo = op`,
        that local gets added to the list of output variables.
    """
        for var in vars_in:
            if var not in self._var_defs:
                raise ValueError(
                    'Referencing undefined variable {}.'.format(var))
        self._prepare_for_instruction()
        if vars_out is None:
            vars_out = []
        call = inst.FunctionCallOp(function, _str_list(vars_in),
                                   inst.pattern_map(str, vars_out))
        self._blocks[-1].instructions.append(call)
        for var in inst.pattern_traverse(vars_out):
            self._mark_defined(var)
        return call
Example 10
def _process_block(block, visited, inferred_types, backend):
  """Executes a pass of type inference on a single `Block`."""
  for op in block.instructions:
    log_debug('handle op {}'.format(op))
    if isinstance(op, instructions.PrimOp):
      if not all(_is_determined(inferred_types[var]) for var in op.vars_in):
        continue
      types_in = [inferred_types[var] for var in op.vars_in]
      # Offer type hints for cases where we need to type non-Tensor literals.
      preferred_types_out = instructions.pattern_map(
          lambda var: inferred_types[var], op.vars_out)
      with _type_inferring():
        objs_out = backend.run_on_dummies(
            op.function, _add_incompatible_batch_dim(types_in))
      types_out = _strip_batch_dim(instructions.pattern_map2(
          lambda tp, val: type_of_pattern(val, backend, preferred_type=tp),
          preferred_types_out, objs_out, leaf_type=instructions.Type))
      _merge_vars(op.vars_out, types_out, inferred_types, backend,
                  log_message='update PrimOp vars_out')
    elif isinstance(op, instructions.FunctionCallOp):
      if not all(_is_determined(inferred_types[var]) for var in op.vars_in):
        continue
      # First, bind op.vars_in to op.function.vars_in.
      types_in = [inferred_types[var].tensors for var in op.vars_in]
      _merge_vars(op.function.vars_in, types_in, inferred_types, backend,
                  log_message='init function vars_in')
      # Execute type inference.
      types_out = op.function.type_inference(types_in)
      for leaf in instructions.pattern_traverse(
          types_out, leaf_type=instructions.TensorType):
        if not isinstance(leaf, instructions.TensorType):
          msg = ('Expected function output type to be '
                 'a nested list or tuple of TensorType, found {}.').format(leaf)
          raise TypeError(msg)
      # To help with typing recursive base-case return literals, we seed
      # return_vars types before stepping into the function.
      _merge_vars(op.function.vars_out, types_out, inferred_types, backend,
                  log_message='update function vars_out')
      # Finally, update op.vars_out with the results of type inference.
      _merge_vars(op.vars_out, types_out, inferred_types, backend,
                  log_message='update FunctionCall vars_out')
      # Step into function. Note: it will only be visited once, if recursive.
      _process_graph(op.function.graph, visited, inferred_types, backend)
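
For the PrimOp branch above, the overall flow is: pad the known input types with the incompatible batch dimension, run the primitive on dummies, then strip that dimension back off the observed outputs. A simplified numpy stand-in for that round trip (hypothetical `infer_primop_event_shape`, not the real backend):

import numpy as np

def infer_primop_event_shape(fn, event_shapes, dtype=np.float32):
    # Fake batch dimension incompatible with every known event dimension.
    fake_batch = max([2] + [max((0,) + s) + 1 for s in event_shapes])
    dummies = [np.zeros((fake_batch,) + s, dtype=dtype) for s in event_shapes]
    result = fn(*dummies)
    return result.shape[1:]  # strip the fake batch dimension again

print(infer_primop_event_shape(lambda a, b: a + b, [(3,), (3,)]))  # (3,)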
Example 11
    def return_(self, vars_out):
        """Records a function return instruction.

    Example:
    ```python
    ab = dsl.ProgramBuilder()

    with ab.function(...) as f:
      ...
      ab.var.result = ...
      ab.return_(ab.var.result)
    ```

    A `return_` command must occur at the top level of the function definition
    (not inside any `if_`s), and must be the last statement therein.  You can
    always achieve this by assigning to a dedicated variable for the answer
    where you would otherwise return (and massaging your control flow).

    Args:
      vars_out: Pattern of Python strings giving the auto-batched variables to
        return.

    Raises:
      ValueError: If invoked more than once in a function body, or if trying to
        return variables that have not been written to.
    """
        # Assume the return_ call is at the top level, and the last statement in the
        # body.  If return_ is nested, the terminator may be overwritten
        # incorrectly.  If return_ is followed by something else, extra instructions
        # may get inserted before the return (because return_ doesn't set up a Block
        # to catch them).
        self._prepare_for_instruction()
        for var in inst.pattern_traverse(vars_out):
            if var not in self._var_defs:
                raise ValueError(
                    'Returning undefined variable {}.'.format(var))
        if self._functions[-1].vars_out:
            raise ValueError(
                'Function body must have exactly one return_ statement')
        self._functions[-1].vars_out = inst.pattern_map(str, vars_out)
        self._blocks[-1].terminator = inst.halt_op()
Example 12
def _strip_batch_dim(type_):
    return instructions.pattern_map(
        lambda t: instructions.TensorType(t.dtype, t.shape[1:]),
        type_,
        leaf_type=instructions.TensorType)
Example 13
def _interpret(program, mask, backend, block_code_cache, *inputs):
    """Worker function for `execute`; operates under a mask."""
    environment = inst.Environment.initialize(backend, program.var_alloc,
                                              program.var_defs, 0,
                                              backend.batch_size(mask))
    for var, inp in inst.pattern_zip(program.vars_in, inputs):
        environment[var] = environment.push(var, inp, mask)
    program_counter = 0  # Index of initial block
    queue = ExecutionQueue(backend)
    while program_counter != program.graph.exit_index():
        block = program.graph.block(program_counter)
        for split_idx, split in enumerate(_split_fn_calls(block.instructions)):
            if isinstance(split, inst.FunctionCallOp):
                op = split
                if (inst.pc_var in inst.pattern_flatten(op.vars_in)
                        or inst.pc_var in inst.pattern_flatten(op.vars_out)):
                    raise ValueError(
                        'FunctionCallOp reads or writes program counter: {}'.
                        format(op))
                inputs = [
                    inst.pattern_map(environment.read, var_pat)
                    for var_pat in op.vars_in
                ]
                # Option: Could gather+scatter at function boundaries.  Then the inner
                # interpreter would not need to accept the mask, but would need to
                # recompute the batch size and make a new mask of all ones.
                outputs = _invoke_fun(program, mask, backend, block_code_cache,
                                      op.function, inputs)
                new_vars = [(varname, environment.push(varname, output, mask))
                            for varname, output in inst.pattern_zip(
                                op.vars_out, outputs)]
                environment = inst.Environment(environment.env_dict,
                                               environment.backend,
                                               update=new_vars)
            else:  # This split is not a FunctionCallOp.
                block_code_key = (id(program.graph), program_counter,
                                  split_idx)
                if block_code_key not in block_code_cache:
                    logging.vlog(1, 'Fill block cache for block %s',
                                 block_code_key)
                    varnames = inst.extract_referenced_variables(split)
                    code = backend.wrap_straightline_callable(
                        functools.partial(_run_straightline, split,
                                          environment.backend))
                    block_code_cache[block_code_key] = (varnames, code)
                else:
                    logging.vlog(1, 'Use cached code for block %s',
                                 block_code_key)
                varnames, code = block_code_cache[block_code_key]
                filtered_env = dict(
                    {  # Only pass variables relevant to these ops
                        k: v
                        for k, v in six.iteritems(environment.env_dict)
                        if k in varnames
                    })
                environment = inst.Environment(environment.env_dict,
                                               environment.backend,
                                               update=code(filtered_env, mask))
        op = block.terminator
        if isinstance(op, inst.BranchOp):
            if inst.pc_var == op.cond_var:
                raise ValueError('Branching on program counter: {}'.format(op))
            condition = environment.read(op.cond_var)
            true_index = program.graph.block_index(op.true_block)
            false_index = program.graph.block_index(op.false_block)
            queue.enqueue(true_index, mask & condition)
            queue.enqueue(false_index, mask & ~condition)
        elif isinstance(op, inst.GotoOp):
            next_index = program.graph.block_index(op.block)
            queue.enqueue(next_index, mask)
        else:
            raise TypeError('Unexpected op type: {}'.format(type(op)))
        program_counter, mask = queue.dequeue()
    # assert the queue is now empty
    return inst.pattern_map(environment.read, program.vars_out)
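
The key move in the `BranchOp` case above is that the active mask is partitioned rather than discarded: threads that take the branch and threads that do not are both re-enqueued, each under its own sub-mask. In miniature (numpy booleans standing in for the backend's mask type):

import numpy as np

mask = np.array([True, True, False, True])        # threads active in this block
condition = np.array([True, False, False, True])  # per-thread branch condition
true_mask = mask & condition    # [ True False False  True]
false_mask = mask & ~condition  # [False  True False False]
# Every active thread lands in exactly one of the two masks, so enqueueing
# (true_block, true_mask) and (false_block, false_mask) loses no work.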
Example 14
    def primop(self, f, vars_in=None, vars_out=None):
        """Records a primitive operation.

    Example:

    ```
    ab = dsl.ProgramBuilder()

    ab.var.five = ab.const(5)
    # Implicit output binding
    ab.var.ten = ab.primop(lambda five: five + five)
    # Explicit output binding
    ab.primop(lambda: (5, 10), vars_out=[ab.var.five, ab.var.ten])
    ```

    Args:
      f: A Python callable, the primitive operation to perform.  Can be
        an inline lambda expression in simple cases.  Must return a list or
        tuple of results, one for each intended output variable.
      vars_in: A list of Python strings, giving the auto-batched variables
        to pass into the callable when invoking it.  If absent, `primop`
        will try to infer it by inspecting the argument list of the callable
        and matching against variables bound in the local scope.
      vars_out: A pattern of Python strings, giving the auto-batched variable(s)
        to which to write the result of the callable.  Defaults to the empty
        list.

    Raises:
      ValueError: If the definition is invalid, if the primop references
        undefined auto-batched variables, or if auto-detection of input
        variables fails.

    Returns:
      op: An `instructions.PrimOp` instance representing this operation.  If one
        subsequently assigns this to a local, via `ProgramBuilder.var.foo = op`,
        that local becomes the output pattern.
    """
        self._prepare_for_instruction()
        if vars_out is None:
            vars_out = []
        if vars_in is None:
            # Deduce the intended variable names from the argument list of the callee.
            # Expected use case: the callee is an inline lambda expression.
            args, varargs, keywords, _ = inspect.getargspec(f)
            vars_in = []
            for arg in args:
                if arg in self._locals:
                    vars_in.append(self._locals[arg])
                else:
                    raise ValueError(
                        'Auto-referencing unbound variable {}.'.format(arg))
            if varargs is not None:
                raise ValueError('Varargs are not supported for primops')
            if keywords is not None:
                raise ValueError('kwargs are not supported for primops')
        for var in vars_in:
            if var not in self._var_defs:
                raise ValueError(
                    'Referencing undefined variable {}.'.format(var))
        prim = inst.prim_op(_str_list(vars_in),
                            inst.pattern_map(str, vars_out), f)
        self._blocks[-1].instructions.append(prim)
        for var in inst.pattern_traverse(vars_out):
            self._mark_defined(var)
        return prim
Example 15
def execute(program, args, max_stack_depth, backend, block_code_cache=None):
    """Executes or stages a complete auto-batching VM program.

  Whether this executes or stages computation depends on whether the backend has
  an eager or deferred computation model.

  The dimensions of the inputs and internal variables are split into
  one top batch dimension and an arbitrary number (here `E`) event
  dimensions.  The event rank may be different for different inputs,
  outputs, and internal variables.

  Args:
    program: An `instructions.Program` to execute or stage.
    args: Input values, a list of arrays, each of shape `[batch_size,
      e1, ..., eE]`.  The batch size must be the same for all inputs.
      The other dimensions must agree with the declared shapes of the
      variables they will be stored in, but need not in general be the
      same as one another.
    max_stack_depth: Python `int`. Maximum depth of stack to allocate.
    backend: Object implementing required backend operations.
    block_code_cache: Dict (allows the cache to live across calls to
      `vm.execute`), or `None` (in which case a dict is created and used per
      call).

  Returns:
    results: A list of the output values. Each returned value is an
      array of shape `[batch_size, e1, ..., eE]`.  The results are
      returned in the same order as the variables appear in
      `program.out_vars`.
  """
    program = select_block_priority(program)
    halt_index = program.graph.exit_index()
    logging.vlog(1, 'Staging computation: %d blocks to stage', halt_index)
    valid_indices = range(halt_index)
    assert len(program.vars_in) == len(args)
    init_vals = dict(zip(program.vars_in, args))
    environment = _initialize_environment(program, init_vals, max_stack_depth,
                                          backend)
    next_block_index = _choose_next_op(environment, backend)

    if block_code_cache is None:
        block_code_cache = {}

    def make_run_block_callable(env):
        """Makes a i->next_env callable using cached, backend-wrapped _run_block."""
        def run_block_callable(i):
            if i not in block_code_cache:
                logging.vlog(1, 'Fill block code cache: block %d', i)
                block_code_cache[i] = backend.wrap_straightline_callable(
                    lambda env_arg: _run_block(program.graph, i, env_arg,
                                               backend))
            else:
                logging.vlog(1, 'Use cached block code: block %d', i)
            return block_code_cache[i](env)

        return run_block_callable

    def cond(_, next_block_index):
        return backend.not_equal(next_block_index, halt_index)

    def body(env_dict, next_block_index):  # pylint:disable=missing-docstring
        # This `_staged_apply` turns into the block dispatch tree (see
        # docstring of `_staged_apply`).
        # Briefly, this will build a graph snippet for each basic block
        # in the control flow graph, and glue them together with a switch
        # on the runtime value of `next_block_index`.
        env_dict = backend.prepare_for_cond(env_dict)
        f = make_run_block_callable(env_dict)
        env_dict = backend.switch_case(
            next_block_index, [functools.partial(f, i) for i in valid_indices])
        next_block_index = _choose_next_op(inst.Environment(env_dict, backend),
                                           backend)
        return env_dict, next_block_index

    env_dict, _ = backend.while_loop(cond, body,
                                     [environment.env_dict, next_block_index])
    final_env = inst.Environment(env_dict, backend)
    return inst.pattern_map(final_env.read, program.vars_out)
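
Stripped of staging, caching, and masking, the loop that `cond` and `body` build is just block dispatch on a program counter. A conceptual eager-mode sketch (hypothetical `run_program`, not this module's API):

def run_program(blocks, halt_index, env):
    # `blocks` maps a block index to a callable env -> (new_env, next_index).
    pc = 0
    while pc != halt_index:
        env, pc = blocks[pc](env)
    return env

blocks = {
    0: lambda env: ({'x': env['x'] + 1}, 0 if env['x'] + 1 < 3 else 1),
    1: lambda env: ({'x': env['x'] * 2}, 2),
}
print(run_program(blocks, halt_index=2, env={'x': 0}))  # {'x': 6}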
Example 16
def _run_block(graph, index, env_dict, backend):
    """Executes or stages one basic block.

  When staging, the `graph`, the `index`, and the `backend` are static.
  The `environment` contains the runtime values (e.g. `Tensor`s or
  `np.ndarray`s) being staged.

  Args:
    graph: The `instructions.ControlFlowGraph` of blocks that exist in
      the program being run.
    index: The int index of the specific block to execute or stage.
    env_dict: The current variable environment of the VM (for all logical
      threads). The program counter is included as a variable with a
      reserved name.
    backend: Object implementing required backend operations.

  Returns:
    new_environment: The new variable environment after performing the
      block on the relevant threads.

  Raises:
    ValueError: Invalid opcode or illegal operation (e.g. branch/read/write
      program counter).
  """
    environment = inst.Environment(env_dict, backend)
    program_counter = environment.read(inst.pc_var)
    block = graph.block(index)
    logging.debug('Staging block %d:\n%s', index, block)
    mask = backend.equal(program_counter, index)

    def as_index(block):
        return backend.broadcast_to_shape_of(graph.block_index(block),
                                             program_counter)

    for op in block.instructions:
        if isinstance(op, inst.PrimOp):
            if (inst.pc_var in inst.pattern_flatten(op.vars_in)
                    or inst.pc_var in inst.pattern_flatten(op.vars_out)):
                raise ValueError(
                    'PrimOp reading or writing program counter: {}'.format(op))
            inputs = [
                inst.pattern_map(environment.read, var_pat)
                for var_pat in op.vars_in
            ]
            with _vm_staging():
                outputs = op.function(*inputs)

            def doit(varname, output, skip_push_mask):
                if varname in skip_push_mask:
                    return environment.update(varname, output, mask)
                else:
                    return environment.push(varname, output, mask)

            new_vars = [
                (varname, doit(varname, output, op.skip_push_mask))
                for varname, output in inst.pattern_zip(op.vars_out, outputs)
            ]
        elif isinstance(op, inst.PopOp):
            new_vars = [(varname, environment.pop(varname, mask))
                        for varname in op.vars]
        else:
            raise ValueError('Invalid instruction in block: {}'.format(
                type(op)))
        environment = inst.Environment(environment.env_dict,
                                       environment.backend,
                                       update=new_vars)
    op = block.terminator
    if isinstance(op, inst.BranchOp):
        if inst.pc_var == op.cond_var:
            raise ValueError('Branching on program counter: {}'.format(op))
        condition = environment.read(op.cond_var)
        next_index = backend.where(condition, as_index(op.true_block),
                                   as_index(op.false_block))
        environment[inst.pc_var] = environment.update(inst.pc_var, next_index,
                                                      mask)
    elif isinstance(op, inst.GotoOp):
        environment[inst.pc_var] = environment.update(inst.pc_var,
                                                      as_index(op.block), mask)
    elif isinstance(op, inst.PushGotoOp):
        environment[inst.pc_var] = environment.update(inst.pc_var,
                                                      as_index(op.push_block),
                                                      mask)
        environment[inst.pc_var] = environment.push(inst.pc_var,
                                                    as_index(op.goto_block),
                                                    mask)
    elif isinstance(op, inst.IndirectGotoOp):
        environment[inst.pc_var] = environment.pop(inst.pc_var, mask)
    else:
        raise TypeError('Unexpected op type: {}'.format(type(op)))
    return environment.env_dict
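
The program-counter update in the `BranchOp` case is the same masked-write trick used for ordinary variables: every thread computes a candidate next block index, and only the threads this block owns (per `mask`) commit it. A tiny numpy illustration (not the backend API):

import numpy as np

pc = np.array([3, 3, 7, 3])                    # current block index per thread
mask = np.array([True, True, False, True])     # threads currently in block 3
condition = np.array([True, False, False, True])
next_pc = np.where(condition, 4, 5)            # candidate target for everyone
pc = np.where(mask, next_pc, pc)               # only masked threads move
print(pc)  # [4 5 7 4]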