Example #1
def _check_initial_dtypes(var_defs, init_vals, backend):
  """Checks that initial values' dtypes match the declared variable types."""
  var_def_dict = dict(var_defs)
  for varname, val in six.iteritems(init_vals):
    for tensor_type, subval in inst.pattern_zip(
        var_def_dict[varname].tensors, val, leaf_type=inst.TensorType):
      backend.assert_matching_dtype(
          tensor_type.dtype, subval, 'var name {}'.format(varname))
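The check above relies on `inst.pattern_zip` pairing up corresponding leaves of two structurally congruent patterns (here, a pattern of `TensorType`s against a pattern of initial values), so `backend.assert_matching_dtype` can be applied leaf by leaf. A minimal, self-contained sketch of that leaf-pairing idea, using a toy stand-in rather than the library function (the real `pattern_zip` also takes a `leaf_type` argument controlling where recursion stops):

def toy_pattern_zip(pattern1, pattern2):
  # Toy stand-in: yield corresponding leaves of two congruent nested tuples.
  if isinstance(pattern1, (list, tuple)):
    for sub1, sub2 in zip(pattern1, pattern2):
      for pair in toy_pattern_zip(sub1, sub2):
        yield pair
  else:
    yield pattern1, pattern2

# Leaf-by-leaf checking in the spirit of _check_initial_dtypes.
declared = ('int32', ('float32', 'float64'))
initial_values = (1, (2.0, 3.0))
for dtype_name, value in toy_pattern_zip(declared, initial_values):
  print(dtype_name, value)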
Example #2
def _run_straightline(ops, backend, env_dict, mask):
    """Imperatively run a list of straight-line ops, return updated `env_dict`."""
    env = inst.Environment(env_dict, backend)
    for op in ops:
        if isinstance(op, inst.PrimOp):
            if (inst.pc_var in inst.pattern_flatten(op.vars_in)
                    or inst.pc_var in inst.pattern_flatten(op.vars_out)):
                raise ValueError(
                    'PrimOp reads or writes program counter: {}'.format(op))
            inputs = [
                inst.pattern_map(env.read, var_pat) for var_pat in op.vars_in
            ]
            with _stackless_running():
                outputs = op.function(*inputs)
            new_vars = [
                (varname, env.push(varname, output, mask))
                for varname, output in inst.pattern_zip(op.vars_out, outputs)
            ]
        elif isinstance(op, inst.PopOp):
            new_vars = [(varname, env.pop(varname, mask))
                        for varname in op.vars]
        else:
            raise ValueError(
                'Invalid instruction in straightline segment: {}'.format(
                    type(op)))
        env = inst.Environment(env.env_dict, env.backend, update=new_vars)
    return env.env_dict
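The loop above never mutates the environment in place; each op's results are folded in by constructing a fresh `inst.Environment` from the previous one plus an `update` list. A toy sketch of that rebuild-with-updates idiom (the `ToyEnvironment` below is only an illustration; the real `Environment` also tracks a backend and per-variable stacks for push/pop):

class ToyEnvironment(object):
  """Toy stand-in: an environment rebuilt functionally after each op."""

  def __init__(self, env_dict, update=None):
    self.env_dict = dict(env_dict)
    if update is not None:
      self.env_dict.update(update)

  def read(self, varname):
    return self.env_dict[varname]

env = ToyEnvironment({'x': 1})
new_vars = [('y', env.read('x') + 1)]
env = ToyEnvironment(env.env_dict, update=new_vars)
print(env.env_dict)  # {'x': 1, 'y': 2}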
Example #3
def _function_entry_stack_ops(op, to_pop):
    """Computes a set of stack operations for the entry to a FunctionCallOp.

  The function calling convention is
  - Push the arguments to the formal parameters
  - Pop any now-dead arguments so they're not on the stack during the call
  - Jump to the beginning of the function body

  This can be a little tricky for a self-call, because then the arguments and
  the formal parameters live in the same name space and can collide.  This
  helper does something reasonable, and errors out when it can't.

  Args:
    op: FunctionCallOp instance giving the call to make stack operations for.
    to_pop: Set of names to make sure are popped before entering.

  Returns:
    ops: List of instruction objects that accomplish the goal.
  """
    push_from = []
    push_to = []
    caller_side_vars = inst.pattern_flatten(op.vars_in)
    callee_side_vars = inst.pattern_flatten(op.function.vars_in)
    for caller_side, callee_side in inst.pattern_zip(caller_side_vars,
                                                     callee_side_vars):
        if caller_side == callee_side:
            # This can happen if this is a self-call and we're just passing the
            # variable through to itself
            if caller_side in to_pop:
                # The top of the stack is already correct, and the callee will pop our
                # unneeded value off it for us: skip the push and the pop.
                to_pop = to_pop.difference([caller_side])
            else:
                # The top of the stack is correct, but we need to push it anyway because
                # the callee will (eventually) pop but we still need the current value
                # when the callee returns.
                push_from.append(caller_side)
                push_to.append(callee_side)
        elif callee_side in caller_side_vars:
            # If the graph of transfers turns out not to be a DAG, can't implement it
            # without temporary space; and don't want to bother computing a safe order
            # even if it does.
            # Fixing this is b/135275883.
            msg = ('Cannot lower FunctionCallOp that reuses its own input {}'
                   ' as a formal parameter.')
            raise ValueError(msg.format(caller_side))
        # Checking `elif caller_side in callee_side_vars` is redundant because
        # the callee_side check will trigger on that pair sooner or later.
        else:
            # Ordinary transfer: push now and then pop if needed.  The pop will not
            # interfere with the push because only caller-side variables can possibly
            # be popped.
            assert callee_side not in to_pop
            push_from.append(caller_side)
            push_to.append(callee_side)
    push_inst = inst.push_op(push_from, push_to)
    if to_pop:
        return [push_inst, inst.PopOp(list(to_pop))]
    else:
        return [push_inst]
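For intuition about the convention the docstring describes, consider a hypothetical call that passes caller variables x and y to formal parameters a and b, where x is dead after the call: it should lower to a push of [x, y] onto [a, b] followed by a pop of [x]. The toy sketch below mirrors that classification loop with plain tuples standing in for `inst.push_op` and `inst.PopOp` (it omits the self-call collision and DAG checks handled above):

def toy_entry_stack_ops(vars_in, formals, to_pop):
  # Toy sketch: plan pushes from caller variables to formals, then pop dead ones.
  push_from, push_to = [], []
  to_pop = set(to_pop)
  for caller_side, callee_side in zip(vars_in, formals):
    if caller_side == callee_side and caller_side in to_pop:
      # Self-call passing a variable through to itself: the callee's pop suffices.
      to_pop.discard(caller_side)
    else:
      push_from.append(caller_side)
      push_to.append(callee_side)
  ops = [('push', push_from, push_to)]
  if to_pop:
    ops.append(('pop', sorted(to_pop)))
  return ops

print(toy_entry_stack_ops(['x', 'y'], ['a', 'b'], to_pop={'x'}))
# [('push', ['x', 'y'], ['a', 'b']), ('pop', ['x'])]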
Example #4
def assertSameTypes(self, expected_prog, typed, check_dtypes=True):
    """Asserts that inferred variable types match those in `expected_prog`."""
    for v, type_ in six.iteritems(typed.var_defs):
        for expected_type, got_type in instructions.pattern_zip(
                expected_prog.var_defs[v].tensors,
                type_.tensors,
                leaf_type=instructions.TensorType):
            if check_dtypes:
                self.assertEqual(expected_type.dtype, got_type.dtype)
            self.assertEqual(expected_type.shape, got_type.shape)
Example #5
def _merge_vars(varnames, vm_types, inferred_types, backend, log_message=None):
    """Merges an updated vm_type for multiple variables.

  Args:
    varnames: Pattern of `string` variable names to merge.
    vm_types: Pattern of `instructions.TensorType` describing the incoming
      values.
    inferred_types: Extant dictionary of inferred types.  This is read
      to obtain the currently inferred types for the `varnames` and mutated
      to incorporate information from `vm_types`.
    backend: Object implementing required backend operations.
    log_message: Optional `string` describing this operation for the log.
  """
    if log_message is not None:
        log_debug(log_message + ': {}'.format(varnames))
    for varname, vm_type in instructions.pattern_zip(
            varnames, vm_types, leaf_type=instructions.Type):
        _merge_var(varname, vm_type, inferred_types, backend)
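The loop above only walks the pattern; the per-variable work happens in `_merge_var`. A toy sketch of the overall shape, with a deliberately simplistic made-up merge rule (adopt the new type if nothing was inferred yet, otherwise require agreement; the real `_merge_var` also consults the backend):

def toy_merge_vars(varnames, vm_types, inferred_types):
  # Toy sketch: fold per-variable type information into a running dict.
  for varname, vm_type in zip(varnames, vm_types):
    old = inferred_types.get(varname)
    if old is None:
      inferred_types[varname] = vm_type
    elif old != vm_type:
      raise ValueError(
          'Conflicting types for {}: {} vs {}'.format(varname, old, vm_type))

inferred = {}
toy_merge_vars(['x', 'y'], ['int32', 'float32'], inferred)
toy_merge_vars(['x'], ['int32'], inferred)  # A consistent re-merge is a no-op.
print(inferred)  # {'x': 'int32', 'y': 'float32'}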
Example #6
def testAutoBatchingMultivalueTF(self):
  input_ = np.array([1, 1, 1], dtype=np.int64)
  output = ((np.array([1, 1, 1], dtype=np.int64),
             np.array([3, 3, 3], dtype=np.int64)),
            np.array([4, 4, 4], dtype=np.int64),
            (np.array([5, 5, 5], dtype=np.int64),
             np.array([6, 6, 6], dtype=np.int64)))
  prog = synthetic_pattern_program()
  # print(prog)
  input_t = tf.constant(input_, dtype=np.int64)
  typed = type_inference.infer_types(prog, [input_t], TF_BACKEND)
  # print(typed)
  alloc = allocation_strategy.optimize(typed)
  lowered = lowering.lower_function_calls(alloc)
  # print(lowered)
  for expected, obtained in instructions.pattern_zip(
      output, self.evaluate(_execute(lowered, input_t, 15, TF_BACKEND))):
    self.assertAllEqual(expected, obtained)
Example #7
def _run_block(graph, index, env_dict, backend):
  """Executes or stages one basic block.

  When staging, the `graph`, the `index`, and the `backend` are static.
  The `environment` contains the runtime values (e.g. `Tensor`s or
  `np.ndarray`s) being staged.

  Args:
    graph: The `instructions.ControlFlowGraph` of blocks that exist in
      the program being run.
    index: The int index of the specific block to execute or stage.
    env_dict: The current variable environment of the VM (for all logical
      threads). The program counter is included as a variable with a
      reserved name.
    backend: Object implementing required backend operations.

  Returns:
    new_environment: The new variable environment after performing the
      block on the relevant threads.

  Raises:
    ValueError: Invalid opcode or illegal operation (e.g. branch/read/write
      program counter).
  """
  environment = inst.Environment(env_dict, backend)
  program_counter = environment.read(inst.pc_var)
  block = graph.block(index)
  logging.debug('Staging block %d:\n%s', index, block)
  mask = backend.equal(program_counter, index)
  def as_index(block):
    return backend.broadcast_to_shape_of(
        graph.block_index(block), program_counter)
  for op in block.instructions:
    if isinstance(op, inst.PrimOp):
      if (inst.pc_var in inst.pattern_flatten(op.vars_in) or
          inst.pc_var in inst.pattern_flatten(op.vars_out)):
        raise ValueError(
            'PrimOp reading or writing program counter: {}'.format(op))
      inputs = [inst.pattern_map(environment.read, var_pat)
                for var_pat in op.vars_in]
      with _vm_staging():
        outputs = op.function(*inputs)
      def doit(varname, output, skip_push_mask):
        if varname in skip_push_mask:
          return environment.update(varname, output, mask)
        else:
          return environment.push(varname, output, mask)
      new_vars = [(varname, doit(varname, output, op.skip_push_mask))
                  for varname, output in inst.pattern_zip(op.vars_out, outputs)]
    elif isinstance(op, inst.PopOp):
      new_vars = [(varname, environment.pop(varname, mask))
                  for varname in op.vars]
    else:
      raise ValueError('Invalid instruction in block: {}'.format(type(op)))
    environment = inst.Environment(
        environment.env_dict, environment.backend, update=new_vars)
  op = block.terminator
  if isinstance(op, inst.BranchOp):
    if inst.pc_var == op.cond_var:
      raise ValueError('Branching on program counter: {}'.format(op))
    condition = environment.read(op.cond_var)
    next_index = backend.where(
        condition, as_index(op.true_block), as_index(op.false_block))
    environment[inst.pc_var] = environment.update(inst.pc_var, next_index, mask)
  elif isinstance(op, inst.GotoOp):
    environment[inst.pc_var] = environment.update(
        inst.pc_var, as_index(op.block), mask)
  elif isinstance(op, inst.PushGotoOp):
    environment[inst.pc_var] = environment.update(
        inst.pc_var, as_index(op.push_block), mask)
    environment[inst.pc_var] = environment.push(
        inst.pc_var, as_index(op.goto_block), mask)
  elif isinstance(op, inst.IndirectGotoOp):
    environment[inst.pc_var] = environment.pop(inst.pc_var, mask)
  else:
    raise TypeError('Unexpected op type: {}'.format(type(op)))
  return environment.env_dict
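All of the terminator cases above update the program counter only for the batch members selected by `mask`; everyone else keeps their old value. A small numpy sketch of that masked, per-thread update, with numpy standing in for the backend's `where` (the indices 1 and 2 below are hypothetical block indices):

import numpy as np

pc = np.array([0, 0, 3, 0])                        # Per-thread program counters.
mask = np.array([True, True, False, True])         # Threads currently in this block.
condition = np.array([True, False, True, False])   # Per-thread branch condition.

# Masked threads jump to block 1 or block 2 depending on the condition;
# the unmasked thread keeps its old program counter (3).
next_pc = np.where(mask, np.where(condition, 1, 2), pc)
print(next_pc)  # [1 2 3 2]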
Example #8
def _interpret(program, mask, backend, block_code_cache, *inputs):
    """Worker function for `execute`; operates under a mask."""
    environment = inst.Environment.initialize(backend, program.var_alloc,
                                              program.var_defs, 0,
                                              backend.batch_size(mask))
    for var, inp in inst.pattern_zip(program.vars_in, inputs):
        environment[var] = environment.push(var, inp, mask)
    program_counter = 0  # Index of initial block
    queue = ExecutionQueue(backend)
    while program_counter != program.graph.exit_index():
        block = program.graph.block(program_counter)
        for split_idx, split in enumerate(_split_fn_calls(block.instructions)):
            if isinstance(split, inst.FunctionCallOp):
                op = split
                if (inst.pc_var in inst.pattern_flatten(op.vars_in)
                        or inst.pc_var in inst.pattern_flatten(op.vars_out)):
                    raise ValueError(
                        'FunctionCallOp reads or writes program counter: {}'.
                        format(op))
                inputs = [
                    inst.pattern_map(environment.read, var_pat)
                    for var_pat in op.vars_in
                ]
                # Option: Could gather+scatter at function boundaries.  Then the inner
                # interpreter would not need to accept the mask, but would need to
                # recompute the batch size and make a new mask of all ones.
                outputs = _invoke_fun(program, mask, backend, block_code_cache,
                                      op.function, inputs)
                new_vars = [(varname, environment.push(varname, output, mask))
                            for varname, output in inst.pattern_zip(
                                op.vars_out, outputs)]
                environment = inst.Environment(environment.env_dict,
                                               environment.backend,
                                               update=new_vars)
            else:  # This split is not a FunctionCallOp.
                block_code_key = (id(program.graph), program_counter,
                                  split_idx)
                if block_code_key not in block_code_cache:
                    logging.info('Fill block cache for block %s',
                                 block_code_key)
                    varnames = inst.extract_referenced_variables(split)
                    code = backend.wrap_straightline_callable(
                        functools.partial(_run_straightline, split,
                                          environment.backend))
                    block_code_cache[block_code_key] = (varnames, code)
                else:
                    logging.info('Use cached code for block %s',
                                 block_code_key)
                varnames, code = block_code_cache[block_code_key]
                # Only pass variables relevant to these ops.
                filtered_env = {
                    k: v
                    for k, v in six.iteritems(environment.env_dict)
                    if k in varnames
                }
                environment = inst.Environment(environment.env_dict,
                                               environment.backend,
                                               update=code(filtered_env, mask))
        op = block.terminator
        if isinstance(op, inst.BranchOp):
            if inst.pc_var == op.cond_var:
                raise ValueError('Branching on program counter: {}'.format(op))
            condition = environment.read(op.cond_var)
            true_index = program.graph.block_index(op.true_block)
            false_index = program.graph.block_index(op.false_block)
            queue.enqueue(true_index, mask & condition)
            queue.enqueue(false_index, mask & ~condition)
        elif isinstance(op, inst.GotoOp):
            next_index = program.graph.block_index(op.block)
            queue.enqueue(next_index, mask)
        else:
            raise TypeError('Unexpected op type: {}'.format(type(op)))
        program_counter, mask = queue.dequeue()
    # assert the queue is now empty
    return inst.pattern_map(environment.read, program.vars_out)
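Unlike the staged interpreter in the previous example, the BranchOp case here does not compute a new program counter directly; it splits the active mask between the two successor blocks and lets the ExecutionQueue schedule each (block, mask) pair. A tiny numpy sketch of that mask split:

import numpy as np

mask = np.array([True, True, False, True])         # Threads active in this block.
condition = np.array([True, False, True, False])   # Per-thread branch condition.

true_mask = mask & condition     # Threads headed to the true block.
false_mask = mask & ~condition   # Threads headed to the false block.
print(true_mask)   # [ True False False False]
print(false_mask)  # [False  True False  True]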