def _check_initial_dtypes(var_defs, init_vals, backend):
  """Checks that the initial values' dtypes match the variable definitions."""
  var_def_dict = dict(var_defs)
  for varname, val in six.iteritems(init_vals):
    for tensor_type, subval in inst.pattern_zip(
        var_def_dict[varname].tensors, val, leaf_type=inst.TensorType):
      backend.assert_matching_dtype(
          tensor_type.dtype, subval, 'var name {}'.format(varname))


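# The helper above leans on `inst.pattern_zip`, which walks two structurally
# identical nests in lockstep; the same traversal shows up in nearly every
# function below.  A minimal standalone sketch of that idea (not the
# library's implementation; `toy_pattern_zip` is a hypothetical name
# introduced here only for illustration, and it assumes the two nests have
# matching tuple/list structure):
def toy_pattern_zip(pattern1, pattern2):
  """Yields matched leaves of two nests with the same tuple/list structure."""
  if isinstance(pattern1, (list, tuple)):
    assert len(pattern1) == len(pattern2), 'Pattern structure mismatch'
    for sub1, sub2 in zip(pattern1, pattern2):
      for pair in toy_pattern_zip(sub1, sub2):
        yield pair
  else:
    yield pattern1, pattern2

# Example: list(toy_pattern_zip(('x_type', ('y_type',)), (1.0, (2,))))
# yields [('x_type', 1.0), ('y_type', 2)], pairing declared types with the
# corresponding initial values leaf by leaf.

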
def _run_straightline(ops, backend, env_dict, mask):
  """Imperatively run a list of straight-line ops, return updated `env_dict`."""
  env = inst.Environment(env_dict, backend)
  for op in ops:
    if isinstance(op, inst.PrimOp):
      if (inst.pc_var in inst.pattern_flatten(op.vars_in) or
          inst.pc_var in inst.pattern_flatten(op.vars_out)):
        raise ValueError(
            'PrimOp reads or writes program counter: {}'.format(op))
      inputs = [inst.pattern_map(env.read, var_pat) for var_pat in op.vars_in]
      with _stackless_running():
        outputs = op.function(*inputs)
      new_vars = [(varname, env.push(varname, output, mask))
                  for varname, output in inst.pattern_zip(op.vars_out, outputs)]
    elif isinstance(op, inst.PopOp):
      new_vars = [(varname, env.pop(varname, mask)) for varname in op.vars]
    else:
      raise ValueError(
          'Invalid instruction in straightline segment: {}'.format(type(op)))
    env = inst.Environment(env.env_dict, env.backend, update=new_vars)
  return env.env_dict


def _function_entry_stack_ops(op, to_pop):
  """Computes a set of stack operations for the entry to a FunctionCallOp.

  The function calling convention is
  - Push the arguments to the formal parameters
  - Pop any now-dead arguments so they're not on the stack during the call
  - Jump to the beginning of the function body

  This can be a little tricky for a self-call, because then the arguments
  and the formal parameters live in the same name space and can collide.
  This helper does something reasonable, and errors out when it can't.

  Args:
    op: FunctionCallOp instance giving the call to make stack operations for.
    to_pop: Set of names to make sure are popped before entering.

  Returns:
    ops: List of instruction objects that accomplish the goal.
  """
  push_from = []
  push_to = []
  caller_side_vars = inst.pattern_flatten(op.vars_in)
  callee_side_vars = inst.pattern_flatten(op.function.vars_in)
  for caller_side, callee_side in inst.pattern_zip(
      caller_side_vars, callee_side_vars):
    if caller_side == callee_side:
      # This can happen if this is a self-call and we're just passing the
      # variable through to itself.
      if caller_side in to_pop:
        # The top of the stack is already correct, and the callee will pop
        # our unneeded value off it for us: skip the push and the pop.
        to_pop = to_pop.difference([caller_side])
      else:
        # The top of the stack is correct, but we need to push it anyway
        # because the callee will (eventually) pop but we still need the
        # current value when the callee returns.
        push_from.append(caller_side)
        push_to.append(callee_side)
    elif callee_side in caller_side_vars:
      # If the graph of transfers turns out not to be a DAG, we can't
      # implement it without temporary space; and we don't want to bother
      # computing a safe order even if it does turn out to be a DAG.
      # Fixing this is b/135275883.
      msg = ('Cannot lower FunctionCallOp that reuses its own input {}'
             ' as a formal parameter.')
      raise ValueError(msg.format(caller_side))
      # Checking `elif caller_side in callee_side_vars` is redundant because
      # the callee_side check will trigger on that pair sooner or later.
    else:
      # Ordinary transfer: push now and then pop if needed.  The pop will not
      # interfere with the push because only caller-side variables can
      # possibly be popped.
      assert callee_side not in to_pop
      push_from.append(caller_side)
      push_to.append(callee_side)
  push_inst = inst.push_op(push_from, push_to)
  if to_pop:
    return [push_inst, inst.PopOp(list(to_pop))]
  else:
    return [push_inst]


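# Worked illustration of the calling convention above (hypothetical variable
# names, not drawn from any real program): for a call whose caller-side
# inputs flatten to ['x', 'y'], whose callee formals flatten to ['a', 'b'],
# and with to_pop={'y'}, both pairs take the "ordinary transfer" branch and
# the helper returns
#     [inst.push_op(['x', 'y'], ['a', 'b']), inst.PopOp(['y'])]
# i.e. push the arguments onto the formals' stacks, then pop the now-dead
# caller-side 'y' so it is not live across the call.  In a self-call where a
# variable is passed through to itself and is also in `to_pop`, both the push
# and the pop are skipped, because the callee's eventual pop of its formal
# parameter removes the caller's value for free.

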
def assertSameTypes(self, expected_prog, typed, check_dtypes=True):
  """Asserts that `typed` inferred the same types as `expected_prog` declares."""
  for v, type_ in six.iteritems(typed.var_defs):
    for expected_type, got_type in instructions.pattern_zip(
        expected_prog.var_defs[v].tensors, type_.tensors,
        leaf_type=instructions.TensorType):
      if check_dtypes:
        self.assertEqual(expected_type.dtype, got_type.dtype)
      self.assertEqual(expected_type.shape, got_type.shape)


def _merge_vars(varnames, vm_types, inferred_types, backend, log_message=None):
  """Merges an updated vm_type for multiple variables.

  Args:
    varnames: Pattern of `string` variable names to merge.
    vm_types: Pattern of `instructions.TensorType` describing the incoming
      values.
    inferred_types: Extant dictionary of inferred types.  This is read to
      obtain the currently inferred types for the `varnames` and mutated to
      incorporate information from `vm_types`.
    backend: Object implementing required backend operations.
    log_message: Optional `string` describing this operation for the log.
  """
  if log_message is not None:
    log_debug(log_message + ': {}'.format(varnames))
  for varname, vm_type in instructions.pattern_zip(
      varnames, vm_types, leaf_type=instructions.Type):
    _merge_var(varname, vm_type, inferred_types, backend)


def testAutoBatchingMultivalueTF(self):
  input_ = np.array([1, 1, 1], dtype=np.int64)
  output = ((np.array([1, 1, 1], dtype=np.int64),
             np.array([3, 3, 3], dtype=np.int64)),
            np.array([4, 4, 4], dtype=np.int64),
            (np.array([5, 5, 5], dtype=np.int64),
             np.array([6, 6, 6], dtype=np.int64)))
  prog = synthetic_pattern_program()
  # print(prog)
  input_t = tf.constant(input_, dtype=np.int64)
  typed = type_inference.infer_types(prog, [input_t], TF_BACKEND)
  # print(typed)
  alloc = allocation_strategy.optimize(typed)
  lowered = lowering.lower_function_calls(alloc)
  # print(lowered)
  for expected, obtained in instructions.pattern_zip(
      output, self.evaluate(_execute(lowered, input_t, 15, TF_BACKEND))):
    self.assertAllEqual(expected, obtained)


def _run_block(graph, index, env_dict, backend):
  """Executes or stages one basic block.

  When staging, the `graph`, the `index`, and the `backend` are static.
  The `environment` contains the runtime values (e.g. `Tensor`s or
  `np.ndarray`s) being staged.

  Args:
    graph: The `instructions.ControlFlowGraph` of blocks that exist in the
      program being run.
    index: The int index of the specific block to execute or stage.
    env_dict: The current variable environment of the VM (for all logical
      threads).  The program counter is included as a variable with a
      reserved name.
    backend: Object implementing required backend operations.

  Returns:
    new_environment: The new variable environment after performing the block
      on the relevant threads.

  Raises:
    ValueError: Invalid opcode or illegal operation (e.g. branch/read/write
      program counter).
  """
  environment = inst.Environment(env_dict, backend)
  program_counter = environment.read(inst.pc_var)
  block = graph.block(index)
  logging.debug('Staging block %d:\n%s', index, block)
  mask = backend.equal(program_counter, index)

  def as_index(block):
    return backend.broadcast_to_shape_of(
        graph.block_index(block), program_counter)

  for op in block.instructions:
    if isinstance(op, inst.PrimOp):
      if (inst.pc_var in inst.pattern_flatten(op.vars_in) or
          inst.pc_var in inst.pattern_flatten(op.vars_out)):
        raise ValueError(
            'PrimOp reading or writing program counter: {}'.format(op))
      inputs = [inst.pattern_map(environment.read, var_pat)
                for var_pat in op.vars_in]
      with _vm_staging():
        outputs = op.function(*inputs)

      def doit(varname, output, skip_push_mask):
        if varname in skip_push_mask:
          return environment.update(varname, output, mask)
        else:
          return environment.push(varname, output, mask)

      new_vars = [(varname, doit(varname, output, op.skip_push_mask))
                  for varname, output in inst.pattern_zip(op.vars_out, outputs)]
    elif isinstance(op, inst.PopOp):
      new_vars = [(varname, environment.pop(varname, mask))
                  for varname in op.vars]
    else:
      raise ValueError('Invalid instruction in block: {}'.format(type(op)))
    environment = inst.Environment(
        environment.env_dict, environment.backend, update=new_vars)
  op = block.terminator
  if isinstance(op, inst.BranchOp):
    if inst.pc_var == op.cond_var:
      raise ValueError('Branching on program counter: {}'.format(op))
    condition = environment.read(op.cond_var)
    next_index = backend.where(
        condition, as_index(op.true_block), as_index(op.false_block))
    environment[inst.pc_var] = environment.update(inst.pc_var, next_index, mask)
  elif isinstance(op, inst.GotoOp):
    environment[inst.pc_var] = environment.update(
        inst.pc_var, as_index(op.block), mask)
  elif isinstance(op, inst.PushGotoOp):
    environment[inst.pc_var] = environment.update(
        inst.pc_var, as_index(op.push_block), mask)
    environment[inst.pc_var] = environment.push(
        inst.pc_var, as_index(op.goto_block), mask)
  elif isinstance(op, inst.IndirectGotoOp):
    environment[inst.pc_var] = environment.pop(inst.pc_var, mask)
  else:
    raise TypeError('Unexpected op type: {}'.format(type(op)))
  return environment.env_dict


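# The masking discipline above is what lets one staged block serve every
# logical thread at once: a thread's variables only change where its program
# counter equals this block's index.  A minimal standalone numpy sketch of a
# masked update (a simplification; the library routes this through
# `Environment.update`/`push` and the backend, not raw `np.where`):
import numpy as np


def toy_masked_update(old_values, new_values, mask):
  """Returns new_values where mask is True and old_values elsewhere."""
  return np.where(mask, new_values, old_values)

# Threads 0 and 2 are "at" this block; thread 1 is elsewhere and is untouched:
# toy_masked_update(np.array([10, 20, 30]),
#                   np.array([11, 21, 31]),
#                   np.array([True, False, True]))
# -> array([11, 20, 31])

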
def _interpret(program, mask, backend, block_code_cache, *inputs):
  """Worker function for `execute`; operates under a mask."""
  environment = inst.Environment.initialize(
      backend, program.var_alloc, program.var_defs, 0,
      backend.batch_size(mask))
  for var, inp in inst.pattern_zip(program.vars_in, inputs):
    environment[var] = environment.push(var, inp, mask)
  program_counter = 0  # Index of initial block
  queue = ExecutionQueue(backend)
  while program_counter != program.graph.exit_index():
    block = program.graph.block(program_counter)
    for split_idx, split in enumerate(_split_fn_calls(block.instructions)):
      if isinstance(split, inst.FunctionCallOp):
        op = split
        if (inst.pc_var in inst.pattern_flatten(op.vars_in) or
            inst.pc_var in inst.pattern_flatten(op.vars_out)):
          raise ValueError(
              'FunctionCallOp reads or writes program counter: {}'.format(op))
        inputs = [inst.pattern_map(environment.read, var_pat)
                  for var_pat in op.vars_in]
        # Option: Could gather+scatter at function boundaries.  Then the inner
        # interpreter would not need to accept the mask, but would need to
        # recompute the batch size and make a new mask of all ones.
        outputs = _invoke_fun(program, mask, backend, block_code_cache,
                              op.function, inputs)
        new_vars = [(varname, environment.push(varname, output, mask))
                    for varname, output in inst.pattern_zip(
                        op.vars_out, outputs)]
        environment = inst.Environment(
            environment.env_dict, environment.backend, update=new_vars)
      else:
        # This split is not a FunctionCallOp.
        block_code_key = (id(program.graph), program_counter, split_idx)
        if block_code_key not in block_code_cache:
          logging.info('Fill block cache for block %s', block_code_key)
          varnames = inst.extract_referenced_variables(split)
          code = backend.wrap_straightline_callable(
              functools.partial(_run_straightline, split, environment.backend))
          block_code_cache[block_code_key] = (varnames, code)
        else:
          logging.info('Use cached code for block %s', block_code_key)
          varnames, code = block_code_cache[block_code_key]
        filtered_env = dict({  # Only pass variables relevant to these ops
            k: v for k, v in six.iteritems(environment.env_dict)
            if k in varnames})
        environment = inst.Environment(
            environment.env_dict, environment.backend,
            update=code(filtered_env, mask))
    op = block.terminator
    if isinstance(op, inst.BranchOp):
      if inst.pc_var == op.cond_var:
        raise ValueError('Branching on program counter: {}'.format(op))
      condition = environment.read(op.cond_var)
      true_index = program.graph.block_index(op.true_block)
      false_index = program.graph.block_index(op.false_block)
      queue.enqueue(true_index, mask & condition)
      queue.enqueue(false_index, mask & ~condition)
    elif isinstance(op, inst.GotoOp):
      next_index = program.graph.block_index(op.block)
      queue.enqueue(next_index, mask)
    else:
      raise TypeError('Unexpected op type: {}'.format(type(op)))
    program_counter, mask = queue.dequeue()
  # assert the queue is now empty
  return inst.pattern_map(environment.read, program.vars_out)


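# The interpreter above drives its logical threads through the control-flow
# graph with a queue of (block index, thread mask) pairs: a branch splits the
# current mask between the two successor blocks, and dequeuing picks the next
# block to run together with the threads headed there.  Below is a toy
# standalone sketch of that bookkeeping; `ToyExecutionQueue` is a hypothetical
# name, and the library's `ExecutionQueue` is backend-aware and may differ in
# its ordering and mask-merging details.
import collections
import numpy as np


class ToyExecutionQueue(object):
  """FIFO of block indices, coalescing the masks of threads headed to each."""

  def __init__(self):
    self._masks = collections.OrderedDict()

  def enqueue(self, block_index, mask):
    if not np.any(mask):
      return  # No thread is actually going to this block.
    if block_index in self._masks:
      # Merge with threads already scheduled for the same block.
      self._masks[block_index] = self._masks[block_index] | mask
    else:
      self._masks[block_index] = mask

  def dequeue(self):
    # Pop the earliest-enqueued block together with its merged mask.
    return self._masks.popitem(last=False)

# Example: two threads branch to block 3 and one to block 5; block 3 is then
# dequeued with the merged mask [True, True, False].
# q = ToyExecutionQueue()
# q.enqueue(3, np.array([True, False, False]))
# q.enqueue(5, np.array([False, False, True]))
# q.enqueue(3, np.array([False, True, False]))
# q.dequeue()  # -> (3, array([ True,  True, False]))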