def type_of_pattern(val, backend, preferred_type=None):
  """Returns the `instructions.Type` of `val`.

  Args:
    val: Pattern of backend-specific `Tensor`s or a Python or numpy constant.
    backend: Object implementing required backend operations.
    preferred_type: `instructions.Type` to prefer, if `val` is a constant.

  Returns:
    vm_type: Pattern of `instructions.TensorType` describing `val`.

  Raises:
    ValueError: If a structured leaf of `preferred_type` is matched against
      an object that is not a list or tuple.
  """
  if preferred_type is None:
    preferred_type = instructions.Type(None)

  def leaf_type_of(hint, obj):
    """Types `obj` at one leaf of the `preferred_type` pattern."""
    if isinstance(hint, instructions.TensorType):
      # A concrete TensorType hint pins the dtype for the backend.
      return backend.type_of(obj, hint.dtype)
    if hint is None:
      # No hint available: let the backend infer the type on its own.
      return instructions.pattern_map(backend.type_of, obj)
    # Otherwise `hint` is a (nested) list or tuple of TensorType while `obj`
    # is not a list or tuple of anything.  pattern_map2 should already have
    # raised in that situation, but raise defensively here as well.
    raise ValueError(
        'Type mismatch: Expected structured type {}, got object {}.'.format(
            hint, obj))

  return instructions.pattern_map2(
      leaf_type_of, preferred_type.tensors, val,
      leaf_type=instructions.TensorType)
def _merge_var(varname, obtained_type, inferred_types, backend):
  """Merges an updated auto-batching type for a single variable.

  Reads the current type of `varname` from `inferred_types`, merges it
  leaf-wise with `obtained_type`, and writes the result back, logging any
  change.

  Args:
    varname: Name of the variable whose type is being updated.
    obtained_type: Newly observed type pattern for the variable.
    inferred_types: Mutable mapping from variable name to `instructions.Type`;
      updated in place.
    backend: Object implementing required backend operations, forwarded to
      the per-leaf tensor-type merge.
  """
  previous = inferred_types[varname]
  merge_leaf = lambda old, new: _merge_tensor_type(old, new, backend=backend)
  merged_tensors = instructions.pattern_map2(
      merge_leaf, previous.tensors, obtained_type,
      leaf_type=instructions.TensorType)
  updated = instructions.Type(merged_tensors)
  inferred_types[varname] = updated
  if previous != updated:
    log_debug('{}: {} -> {}'.format(varname, previous, updated))
def _process_block(block, visited, inferred_types, backend):
  """Executes a pass of type inference on a single `Block`.

  Walks the block's instructions in order.  `PrimOp`s are typed by running
  the op's function on dummy inputs; `FunctionCallOp`s are typed via the
  callee's declared `type_inference`, then the callee's graph is traversed.
  An op is skipped (left for a later pass) until all its input variable
  types are determined.

  Args:
    block: The `instructions.Block` to process.
    visited: Bookkeeping passed through to `_process_graph` so each function
      graph is stepped into only once, even when recursive.
    inferred_types: Mutable mapping from variable name to `instructions.Type`;
      updated in place via `_merge_vars`.
    backend: Object implementing required backend operations.

  Raises:
    TypeError: If a function's `type_inference` yields a leaf that is not an
      `instructions.TensorType`.
  """
  for op in block.instructions:
    log_debug('handle op {}'.format(op))
    if isinstance(op, instructions.PrimOp):
      # Defer this op until every input variable has a determined type.
      if not all(_is_determined(inferred_types[var]) for var in op.vars_in):
        continue
      types_in = [inferred_types[var] for var in op.vars_in]
      # Offer type hints for cases where we need to type non-Tensor literals.
      preferred_types_out = instructions.pattern_map(
          lambda var: inferred_types[var], op.vars_out)
      # Run the primitive on dummy values to observe its output types.
      with _type_inferring():
        objs_out = backend.run_on_dummies(
            op.function, _add_incompatible_batch_dim(types_in))
      types_out = _strip_batch_dim(instructions.pattern_map2(
          lambda tp, val: type_of_pattern(val, backend, preferred_type=tp),
          preferred_types_out, objs_out, leaf_type=instructions.Type))
      _merge_vars(op.vars_out, types_out, inferred_types, backend,
                  log_message='update PrimOp vars_out')
    elif isinstance(op, instructions.FunctionCallOp):
      # Defer this op until every input variable has a determined type.
      if not all(_is_determined(inferred_types[var]) for var in op.vars_in):
        continue
      # First, bind op.vars_in to op.function.vars_in.
      types_in = [inferred_types[var].tensors for var in op.vars_in]
      _merge_vars(op.function.vars_in, types_in, inferred_types, backend,
                  log_message='init function vars_in')
      # Execute type inference.
      types_out = op.function.type_inference(types_in)
      # Validate that the declared output types are (nested) TensorTypes.
      for leaf in instructions.pattern_traverse(
          types_out, leaf_type=instructions.TensorType):
        if not isinstance(leaf, instructions.TensorType):
          msg = ('Expected function output type to be '
                 'a nested list or tuple of TensorType, found {}.').format(
                     leaf)
          raise TypeError(msg)
      # To help with typing recursive base-case return literals, we seed
      # return_vars types before stepping into the function.
      _merge_vars(op.function.vars_out, types_out, inferred_types, backend,
                  log_message='update function vars_out')
      # Finally, update op.vars_out with the results of type inference.
      _merge_vars(op.vars_out, types_out, inferred_types, backend,
                  log_message='update FunctionCall vars_out')
      # Step into function. Note: it will only be visited once, if recursive.
      _process_graph(op.function.graph, visited, inferred_types, backend)