Beispiel #1
0
def _run_straightline(ops, backend, env_dict, mask):
    """Imperatively run a list of straight-line ops, return updated `env_dict`."""
    environment = inst.Environment(env_dict, backend)
    for operation in ops:
        if isinstance(operation, inst.PrimOp):
            # The program counter is managed by the interpreter driver;
            # primitive ops must not touch it.
            reads = inst.pattern_flatten(operation.vars_in)
            writes = inst.pattern_flatten(operation.vars_out)
            if inst.pc_var in reads or inst.pc_var in writes:
                raise ValueError(
                    'PrimOp reads or writes program counter: {}'.format(
                        operation))
            arguments = [inst.pattern_map(environment.read, pattern)
                         for pattern in operation.vars_in]
            with _stackless_running():
                results = operation.function(*arguments)
            # Push each output under the mask, recording the new variable
            # values for the environment update below.
            updates = []
            for name, value in inst.pattern_zip(operation.vars_out, results):
                updates.append((name, environment.push(name, value, mask)))
        elif isinstance(operation, inst.PopOp):
            updates = [(name, environment.pop(name, mask))
                       for name in operation.vars]
        else:
            raise ValueError(
                'Invalid instruction in straightline segment: {}'.format(
                    type(operation)))
        # Environments are treated as immutable: rebuild with the updates
        # rather than mutating in place.
        environment = inst.Environment(
            environment.env_dict, environment.backend, update=updates)
    return environment.env_dict
Beispiel #2
0
 def body(env_dict, next_block_index):  # pylint:disable=missing-docstring
     # One iteration of the interpreter's driver loop: run whichever basic
     # block `next_block_index` selects, then recompute the next block
     # index from the updated environment.
     # This `_staged_apply` turns into the block dispatch tree (see
     # docstring of `_staged_apply`).
     # Briefly, this will build a graph snippet for each basic block
     # in the control flow graph, and glue them together with a switch
     # on the runtime value of `next_block_index`.
     # NOTE(review): `prepare_for_cond` presumably canonicalizes the env
     # dict for use inside the backend's while-loop — confirm with the
     # backend's documentation.
     env_dict = backend.prepare_for_cond(env_dict)
     f = make_run_block_callable(env_dict)
     # `functools.partial(f, i)` freezes each block index, yielding one
     # no-argument branch callable per valid block for the switch.
     env_dict = backend.switch_case(
         next_block_index, [functools.partial(f, i) for i in valid_indices])
     next_block_index = _choose_next_op(inst.Environment(env_dict, backend),
                                        backend)
     return env_dict, next_block_index
Beispiel #3
0
def _interpret(program, mask, backend, block_code_cache, *inputs):
    """Worker function for `execute`; operates under a mask.

    Args:
      program: The `instructions.Program` to run.
      mask: Boolean batch mask of the threads that are active in this
        (possibly recursive) invocation.
      backend: Object implementing required backend operations.
      block_code_cache: Dict mapping (graph id, block index, split index)
        to cached, backend-wrapped straight-line code.
      *inputs: Initial values for `program.vars_in`, in order.

    Returns:
      The values of `program.vars_out`, read from the final environment.

    Raises:
      ValueError: If an op reads, writes, or branches on the program
        counter.
      TypeError: If a block terminator has an unexpected type.
    """
    environment = inst.Environment.initialize(backend, program.var_alloc,
                                              program.var_defs, 0,
                                              backend.batch_size(mask))
    for var, inp in inst.pattern_zip(program.vars_in, inputs):
        environment[var] = environment.push(var, inp, mask)
    program_counter = 0  # Index of initial block
    queue = ExecutionQueue(backend)
    while program_counter != program.graph.exit_index():
        block = program.graph.block(program_counter)
        for split_idx, split in enumerate(_split_fn_calls(block.instructions)):
            if isinstance(split, inst.FunctionCallOp):
                op = split
                if (inst.pc_var in inst.pattern_flatten(op.vars_in)
                        or inst.pc_var in inst.pattern_flatten(op.vars_out)):
                    raise ValueError(
                        'FunctionCallOp reads or writes program counter: {}'.
                        format(op))
                # Named `call_inputs` (not `inputs`) to avoid shadowing this
                # function's own `*inputs` parameter.
                call_inputs = [
                    inst.pattern_map(environment.read, var_pat)
                    for var_pat in op.vars_in
                ]
                # Option: Could gather+scatter at function boundaries.  Then
                # the inner interpreter would not need to accept the mask, but
                # would need to recompute the batch size and make a new mask
                # of all ones.
                outputs = _invoke_fun(program, mask, backend, block_code_cache,
                                      op.function, call_inputs)
                new_vars = [(varname, environment.push(varname, output, mask))
                            for varname, output in inst.pattern_zip(
                                op.vars_out, outputs)]
                environment = inst.Environment(environment.env_dict,
                                               environment.backend,
                                               update=new_vars)
            else:  # This split is not a FunctionCallOp.
                block_code_key = (id(program.graph), program_counter,
                                  split_idx)
                if block_code_key not in block_code_cache:
                    logging.vlog(1, 'Fill block cache for block %s',
                                 block_code_key)
                    varnames = inst.extract_referenced_variables(split)
                    code = backend.wrap_straightline_callable(
                        functools.partial(_run_straightline, split,
                                          environment.backend))
                    block_code_cache[block_code_key] = (varnames, code)
                else:
                    logging.vlog(1, 'Use cached code for block %s',
                                 block_code_key)
                varnames, code = block_code_cache[block_code_key]
                # Only pass variables relevant to these ops.  (A dict
                # comprehension is already a dict; the former `dict(...)`
                # wrapper was a redundant copy.)
                filtered_env = {
                    k: v
                    for k, v in six.iteritems(environment.env_dict)
                    if k in varnames
                }
                environment = inst.Environment(environment.env_dict,
                                               environment.backend,
                                               update=code(filtered_env, mask))
        op = block.terminator
        if isinstance(op, inst.BranchOp):
            if inst.pc_var == op.cond_var:
                raise ValueError('Branching on program counter: {}'.format(op))
            condition = environment.read(op.cond_var)
            true_index = program.graph.block_index(op.true_block)
            false_index = program.graph.block_index(op.false_block)
            # Each batch member follows its own arm: enqueue both successor
            # blocks with complementary masks.
            queue.enqueue(true_index, mask & condition)
            queue.enqueue(false_index, mask & ~condition)
        elif isinstance(op, inst.GotoOp):
            next_index = program.graph.block_index(op.block)
            queue.enqueue(next_index, mask)
        else:
            raise TypeError('Unexpected op type: {}'.format(type(op)))
        program_counter, mask = queue.dequeue()
    # assert the queue is now empty
    return inst.pattern_map(environment.read, program.vars_out)
Beispiel #4
0
def execute(program, args, max_stack_depth, backend, block_code_cache=None):
    """Executes or stages a complete auto-batching VM program.

  Whether this executes or stages computation depends on whether the backend has
  an eager or deferred computation model.

  The dimensions of the inputs and internal variables are split into
  one top batch dimension and an arbitrary number (here `E`) event
  dimensions.  The event rank may be different for different inputs,
  outputs, and internal variables.

  Args:
    program: An `instructions.Program` to execute or stage.
    args: Input values, a list of arrays, each of shape `[batch_size,
      e1, ..., eE]`.  The batch size must be the same for all inputs.
      The other dimensions must agree with the declared shapes of the
      variables they will be stored in, but need not in general be the
      same as one another.
    max_stack_depth: Python `int`. Maximum depth of stack to allocate.
    backend: Object implementing required backend operations.
    block_code_cache: Dict (allows cache to live across calls to
      `vm.execute`), or `None` (in which case a dict is created and used
      per call).

  Returns:
    results: A list of the output values. Each returned value is an
      array of shape `[batch_size, e1, ..., eE]`.  The results are
      returned in the same order as the variables appear in
      `program.vars_out`.
  """
    program = select_block_priority(program)
    halt_index = program.graph.exit_index()
    logging.vlog(1, 'Staging computation: %d blocks to stage', halt_index)
    valid_indices = range(halt_index)
    assert len(program.vars_in) == len(args)
    init_vals = dict(zip(program.vars_in, args))
    environment = _initialize_environment(program, init_vals, max_stack_depth,
                                          backend)
    next_block_index = _choose_next_op(environment, backend)

    if block_code_cache is None:
        block_code_cache = {}

    def make_run_block_callable(env):
        """Makes a i->next_env callable using cached, backend-wrapped _run_block."""
        # The cache is keyed by block index `i`; the wrapped callable is
        # built once per block and reused across loop iterations.
        def run_block_callable(i):
            if i not in block_code_cache:
                logging.vlog(1, 'Fill block code cache: block %d', i)
                block_code_cache[i] = backend.wrap_straightline_callable(
                    lambda env_arg: _run_block(program.graph, i, env_arg,
                                               backend))
            else:
                logging.vlog(1, 'Use cached block code: block %d', i)
            return block_code_cache[i](env)

        return run_block_callable

    def cond(_, next_block_index):
        # Keep looping until the dispatched block index reaches the
        # graph's exit index.
        return backend.not_equal(next_block_index, halt_index)

    def body(env_dict, next_block_index):  # pylint:disable=missing-docstring
        # This `_staged_apply` turns into the block dispatch tree (see
        # docstring of `_staged_apply`).
        # Briefly, this will build a graph snippet for each basic block
        # in the control flow graph, and glue them together with a switch
        # on the runtime value of `next_block_index`.
        env_dict = backend.prepare_for_cond(env_dict)
        f = make_run_block_callable(env_dict)
        env_dict = backend.switch_case(
            next_block_index, [functools.partial(f, i) for i in valid_indices])
        next_block_index = _choose_next_op(inst.Environment(env_dict, backend),
                                           backend)
        return env_dict, next_block_index

    # The loop state is (environment dict, next block index); the backend's
    # while-loop either iterates eagerly or stages a deferred loop.
    env_dict, _ = backend.while_loop(cond, body,
                                     [environment.env_dict, next_block_index])
    final_env = inst.Environment(env_dict, backend)
    return inst.pattern_map(final_env.read, program.vars_out)
Beispiel #5
0
def _run_block(graph, index, env_dict, backend):
    """Executes or stages one basic block.

  When staging, the `graph`, the `index`, and the `backend` are static.
  The `environment` contains the runtime values (e.g. `Tensor`s or
  `np.ndarray`s) being staged.

  Args:
    graph: The `instructions.ControlFlowGraph` of blocks that exist in
      the program being run.
    index: The int index of the specific block to execute or stage.
    env_dict: The current variable environment of the VM (for all logical
      threads). The program counter is included as a variable with a
      reserved name.
    backend: Object implementing required backend operations.

  Returns:
    new_environment: The new variable environment after performing the
      block on the relevant threads.

  Raises:
    ValueError: Invalid opcode or illegal operation (e.g. branch/read/write
      program counter).
  """
    environment = inst.Environment(env_dict, backend)
    program_counter = environment.read(inst.pc_var)
    block = graph.block(index)
    logging.debug('Staging block %d:\n%s', index, block)
    # Only the batch members whose program counter equals this block's
    # index actually take its effects; all updates below are masked.
    mask = backend.equal(program_counter, index)

    def as_index(block):
        # Broadcast the (static) block index to the program counter's
        # shape so masked updates line up elementwise.
        return backend.broadcast_to_shape_of(graph.block_index(block),
                                             program_counter)

    for op in block.instructions:
        if isinstance(op, inst.PrimOp):
            if (inst.pc_var in inst.pattern_flatten(op.vars_in)
                    or inst.pc_var in inst.pattern_flatten(op.vars_out)):
                raise ValueError(
                    'PrimOp reading or writing program counter: {}'.format(op))
            inputs = [
                inst.pattern_map(environment.read, var_pat)
                for var_pat in op.vars_in
            ]
            with _vm_staging():
                outputs = op.function(*inputs)

            def doit(varname, output, skip_push_mask):
                # Variables in `skip_push_mask` are overwritten in place;
                # the rest get a fresh stack entry pushed under the mask.
                if varname in skip_push_mask:
                    return environment.update(varname, output, mask)
                else:
                    return environment.push(varname, output, mask)

            new_vars = [
                (varname, doit(varname, output, op.skip_push_mask))
                for varname, output in inst.pattern_zip(op.vars_out, outputs)
            ]
        elif isinstance(op, inst.PopOp):
            new_vars = [(varname, environment.pop(varname, mask))
                        for varname in op.vars]
        else:
            raise ValueError('Invalid instruction in block: {}'.format(
                type(op)))
        # Rebuild the (immutable) environment with this op's updates.
        environment = inst.Environment(environment.env_dict,
                                       environment.backend,
                                       update=new_vars)
    op = block.terminator
    if isinstance(op, inst.BranchOp):
        if inst.pc_var == op.cond_var:
            raise ValueError('Branching on program counter: {}'.format(op))
        condition = environment.read(op.cond_var)
        # Per-thread select between the two successor block indices.
        next_index = backend.where(condition, as_index(op.true_block),
                                   as_index(op.false_block))
        environment[inst.pc_var] = environment.update(inst.pc_var, next_index,
                                                      mask)
    elif isinstance(op, inst.GotoOp):
        environment[inst.pc_var] = environment.update(inst.pc_var,
                                                      as_index(op.block), mask)
    elif isinstance(op, inst.PushGotoOp):
        # Order matters: first overwrite the current pc-stack top with
        # `push_block` (the address to resume at), then push `goto_block`
        # as the new top, so execution proceeds to `goto_block` and a
        # later IndirectGotoOp pop returns to `push_block`.
        environment[inst.pc_var] = environment.update(inst.pc_var,
                                                      as_index(op.push_block),
                                                      mask)
        environment[inst.pc_var] = environment.push(inst.pc_var,
                                                    as_index(op.goto_block),
                                                    mask)
    elif isinstance(op, inst.IndirectGotoOp):
        # Return: pop the pc stack, resuming at the saved block index.
        environment[inst.pc_var] = environment.pop(inst.pc_var, mask)
    else:
        raise TypeError('Unexpected op type: {}'.format(type(op)))
    return environment.env_dict