Ejemplo n.º 1
0
def type_of_pattern(val, backend, preferred_type=None):
  """Returns the pattern of `instructions.TensorType`s describing `val`.

  Note that the result is the bare pattern (the structure that
  `instructions.Type.tensors` holds), not an `instructions.Type` wrapper.

  Args:
    val: Pattern of backend-specific `Tensor`s or a Python or numpy constant.
    backend: Object implementing required backend operations; its `type_of`
      method is called on each leaf of `val`.
    preferred_type: `instructions.Type` to prefer, if `val` is a constant.
      Where a `TensorType` leaf of `preferred_type.tensors` lines up with a
      leaf of `val`, that leaf's `dtype` is passed to `backend.type_of` as a
      hint. If `None`, no preference is given and every leaf of `val` is
      typed by `backend.type_of` alone.

  Returns:
    vm_type: Pattern of `instructions.TensorType` describing `val`.

  Raises:
    ValueError: If a non-structured leaf of `val` lines up with a structured
      (nested list/tuple) part of `preferred_type.tensors`. (Presumably
      `instructions.pattern_map2` usually detects such a mismatch and raises
      first; this is a defensive check.)
  """
  def at_leaf(preferred_leaf_type, obj):
    """Pattern match at a leaf of the preferred_type pattern."""
    if preferred_leaf_type is None:
      # No preference for this sub-pattern: `obj` may itself be a structure,
      # so type every one of its leaves without a dtype hint.
      return instructions.pattern_map(backend.type_of, obj)
    if isinstance(preferred_leaf_type, instructions.TensorType):
      return backend.type_of(obj, preferred_leaf_type.dtype)
    # Otherwise, preferred_leaf_type must be a (nested) list or tuple of
    # TensorType, while obj is not a list or a tuple (of anything).  In this
    # case, pattern_map2 should have raised an error, but we can defensively
    # raise an error here as well.
    msg = 'Type mismatch: Expected structured type {}, got object {}.'.format(
        preferred_leaf_type, obj)
    raise ValueError(msg)

  if preferred_type is None:
    # A `Type` wrapping `None` makes the whole of `val` line up with a single
    # `None` preference leaf, which `at_leaf` handles above.
    preferred_type = instructions.Type(None)
  return instructions.pattern_map2(at_leaf,
                                   preferred_type.tensors,
                                   val,
                                   leaf_type=instructions.TensorType)
Ejemplo n.º 2
0
def _merge_var(varname, obtained_type, inferred_types, backend):
  """Folds `obtained_type` into the type currently recorded for `varname`.

  The tensors pattern of the existing `inferred_types[varname]` is combined
  leaf-by-leaf (at `instructions.TensorType` granularity) with
  `obtained_type` via `_merge_tensor_type`. The result is wrapped in a new
  `instructions.Type` and stored back into `inferred_types`. A debug line is
  logged whenever the stored type changes.

  Args:
    varname: Key of the variable in `inferred_types`.
    obtained_type: Pattern of `instructions.TensorType` to merge in.
    inferred_types: Mapping from variable name to `instructions.Type`;
      updated in place.
    backend: Backend object forwarded to `_merge_tensor_type`.
  """
  previous = inferred_types[varname]
  merge_leaf = functools.partial(_merge_tensor_type, backend=backend)
  merged_tensors = instructions.pattern_map2(
      merge_leaf,
      previous.tensors,
      obtained_type,
      leaf_type=instructions.TensorType)
  inferred_types[varname] = instructions.Type(merged_tensors)
  updated = inferred_types[varname]
  if previous != updated:
    log_debug('{}: {} -> {}'.format(varname, previous, updated))
Ejemplo n.º 3
0
def _process_block(block, visited, inferred_types, backend):
  """Executes a pass of type inference on a single `Block`.

  Walks `block.instructions` in order. For each `PrimOp` or `FunctionCallOp`
  whose input variables all have determined types, it computes output types
  and merges them into `inferred_types` through `_merge_vars`. Ops with any
  undetermined input are skipped in this pass.

  Args:
    block: The `Block` whose `instructions` are processed.
    visited: Forwarded to `_process_graph` when stepping into a called
      function; presumably records graphs already entered so that recursive
      functions are only visited once -- confirm against `_process_graph`.
    inferred_types: Mapping from variable name to `instructions.Type`. It is
      read here and passed to `_merge_vars`, which presumably updates it in
      place.
    backend: Object implementing required backend operations (used here for
      `run_on_dummies` and forwarded to the type-merging helpers).

  Raises:
    TypeError: If a called function's `type_inference` yields a leaf that is
      not an `instructions.TensorType`.
  """
  for op in block.instructions:
    log_debug('handle op {}'.format(op))
    if isinstance(op, instructions.PrimOp):
      # Inputs not typed yet; presumably a later inference pass revisits this op.
      if not all(_is_determined(inferred_types[var]) for var in op.vars_in):
        continue
      types_in = [inferred_types[var] for var in op.vars_in]
      # Offer type hints for cases where we need to type non-Tensor literals.
      preferred_types_out = instructions.pattern_map(
          lambda var: inferred_types[var], op.vars_out)
      # Execute the primitive on placeholder values derived from the input
      # types, inside the `_type_inferring()` context, to observe what it
      # produces.
      with _type_inferring():
        objs_out = backend.run_on_dummies(
            op.function, _add_incompatible_batch_dim(types_in))
      # Type each output (using the currently inferred type of its variable
      # as the preference for constants). `_strip_batch_dim` presumably
      # undoes the dimension added by `_add_incompatible_batch_dim` above.
      types_out = _strip_batch_dim(instructions.pattern_map2(
          lambda tp, val: type_of_pattern(val, backend, preferred_type=tp),
          preferred_types_out, objs_out, leaf_type=instructions.Type))
      _merge_vars(op.vars_out, types_out, inferred_types, backend,
                  log_message='update PrimOp vars_out')
    elif isinstance(op, instructions.FunctionCallOp):
      # Inputs not typed yet; presumably a later inference pass revisits this op.
      if not all(_is_determined(inferred_types[var]) for var in op.vars_in):
        continue
      # First, bind op.vars_in to op.function.vars_in.
      types_in = [inferred_types[var].tensors for var in op.vars_in]
      _merge_vars(op.function.vars_in, types_in, inferred_types, backend,
                  log_message='init function vars_in')
      # Execute type inference.
      types_out = op.function.type_inference(types_in)
      # Validate that every leaf of the function's result is a TensorType.
      for leaf in instructions.pattern_traverse(
          types_out, leaf_type=instructions.TensorType):
        if not isinstance(leaf, instructions.TensorType):
          msg = ('Expected function output type to be '
                 'a nested list or tuple of TensorType, found {}.').format(leaf)
          raise TypeError(msg)
      # To help with typing recursive base-case return literals, we seed
      # return_vars types before stepping into the function.
      _merge_vars(op.function.vars_out, types_out, inferred_types, backend,
                  log_message='update function vars_out')
      # Finally, update op.vars_out with the results of type inference.
      _merge_vars(op.vars_out, types_out, inferred_types, backend,
                  log_message='update FunctionCall vars_out')
      # Step into function. Note: it will only be visited once, if recursive.
      _process_graph(op.function.graph, visited, inferred_types, backend)