Example #1
def _RemoveDefaultAttrs(producer_op_list, graph_def):
    """Removes unknown default attrs according to `producer_op_list`.

  Removes any unknown attrs in `graph_def` (i.e. attrs that do not appear in
  registered OpDefs) that have a default value in `producer_op_list`.

  Args:
    producer_op_list: OpList proto.
    graph_def: GraphDef proto
  """
    producer_op_dict = {op.name: op for op in producer_op_list.op}
    for node in graph_def.node:
        # Remove any default attr values that aren't in op_def.
        if node.op in producer_op_dict:
            op_def = op_def_registry.get(node.op)
            if op_def is None:
                # Some custom op registrations won't show up here. That's OK, attribute
                # stripping just won't be available.
                continue
            producer_op_def = producer_op_dict[node.op]
            # We make a copy of node.attr to iterate through since we may modify
            # node.attr inside the loop.
            for key in list(node.attr):
                if _FindAttrInOpDef(key, op_def) is None:
                    # No attr_def in consumer, look in producer.
                    attr_def = _FindAttrInOpDef(key, producer_op_def)
                    if (attr_def and attr_def.HasField('default_value')
                            and node.attr[key] == attr_def.default_value):
                        # Unknown attr had default value in producer, delete it so it can be
                        # understood by consumer.
                        del node.attr[key]
Example #2
def _get_num_inputs_outputs(op_type):
    """Returns (num_inputs, num_outputs).

  Args:
    op_type: String. The type of the Operation. Used to lookup the op in the
      registry.

  Returns:
    (num_inputs, num_outputs). For either num_inputs or num_outputs, -1 is
    returned if the value can't be statically inferred from the OpDef alone or
    if the OpDef lookup fails.
  """
    def _is_list_arg(arg):
        return arg.number_attr or arg.type_list_attr

    def _count_args(arg_defs):
        for arg in arg_defs:
            if _is_list_arg(arg):
                # Op has list type args which could be variable.
                return -1
        return len(arg_defs)

    op_def = op_def_registry.get(op_type)
    if not op_def:
        return -1, -1
    return _count_args(op_def.input_arg), _count_args(op_def.output_arg)
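A quick usage sketch for the helper above (hedged: it assumes a standard TensorFlow build in which the ops named below are registered; the calls are illustrative only). Fixed-arity ops yield concrete counts, while list-typed arguments or a failed OpDef lookup yield -1.

# Illustrative only; assumes 'AddV2' and 'ConcatV2' are registered ops.
print(_get_num_inputs_outputs('AddV2'))     # (2, 1): two fixed inputs, one output
print(_get_num_inputs_outputs('ConcatV2'))  # (-1, 1): N inputs via number_attr
print(_get_num_inputs_outputs('NoSuchOp'))  # (-1, -1): OpDef lookup failed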
Example #3
def _ProcessGraphDefParam(graph_def):
    """Type-checks and possibly canonicalizes `graph_def`."""
    if not isinstance(graph_def, graph_pb2.GraphDef):
        # `graph_def` could be a dynamically-created message, so try a duck-typed
        # approach
        try:
            old_graph_def = graph_def
            graph_def = graph_pb2.GraphDef()
            graph_def.MergeFrom(old_graph_def)
        except TypeError:
            raise TypeError('graph_def must be a GraphDef proto.')
    else:
        # If we're using the graph_def provided by the caller, modify graph_def
        # in-place to add attr defaults to the NodeDefs (this is visible to the
        # caller).
        # NOTE(skyewm): this is undocumented behavior that at least meta_graph.py
        # depends on. It might make sense to move this to meta_graph.py and have
        # import_graph_def not modify the graph_def argument (we'd have to make sure
        # this doesn't break anything else.)
        for node in graph_def.node:
            op_def = op_def_registry.get(node.op)
            if op_def is None:
                # Assume unrecognized ops are functions for now. TF_ImportGraphDef will
                # report an error if the op is actually missing.
                continue
            _SetDefaultAttrValues(node, op_def)

    return graph_def
Example #4
  def visit_FunctionDef(self, node):
    # TODO(fengliuai): create one utility method to match different apis and
    # shared it with the tfr_gen.py module.
    compose_dec = []
    for dec in node.decorator_list:
      if isinstance(dec, ast.Call):
        if isinstance(dec.func, ast.Attribute) and dec.func.attr == 'Composite':
          compose_dec.append(dec)
        if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite':
          compose_dec.append(dec)

    if not compose_dec:
      # skip a non-composition function
      return
    elif len(compose_dec) > 1:
      raise KeyError('Found more than one Composite op decorator.')

    all_dec_args = {}
    for arg_name, arg_value in zip(_COMPOSITE_ARG_LIST, compose_dec[0].args):
      all_dec_args[arg_name] = self.visit(arg_value)

    kw_dec_args = dict([self.visit(kw) for kw in compose_dec[0].keywords])

    if all_dec_args.keys() & kw_dec_args.keys():
      raise KeyError('More arguments than expected.')

    all_dec_args.update(kw_dec_args)

    op_name = all_dec_args['op_name']
    op_def = op_def_registry.get(op_name)
    if op_def:
      if len(all_dec_args) > 1:
        # Op has been registered, so it is a user error to specify op def.
        raise ValueError('Op has been registered: ' + op_name)
      else:
        # Op has been registered, so we don't need to generate registration code.
        return

    # Validate that the function inputs match what is declared in the decorator.
    inputs = all_dec_args.get('inputs', [])
    attrs = all_dec_args.get('attrs', [])
    expected_args = [arg.split(':')[0] for arg in inputs + attrs]
    all_func_args = self.visit(node.args)

    if len(expected_args) != len(all_func_args):
      raise KeyError('Composition arguments do not match the registration.')

    cxx_reg_code = '\nREGISTER_OP("{0}")'.format(op_name)
    for input_ in inputs:
      cxx_reg_code += '\n    .Input("{0}")'.format(input_)
    for attr in attrs:
      py_str = attr.replace('"', '\'')
      cxx_reg_code += '\n    .Attr("{0}")'.format(py_str)
    for attr in all_dec_args.get('derived_attrs', []):
      py_str = attr.replace('"', '\'')
      cxx_reg_code += '\n    .Attr("{0}")'.format(py_str)
    for output_ in all_dec_args.get('outputs', []):
      cxx_reg_code += '\n    .Output("{0}")'.format(output_)
    cxx_reg_code += ';\n'
    self.emit(cxx_reg_code)
Example #5
    def get_metaopdef(cls, name):
        """Obtain a MetaOpDef for a given string name.

        This is more flexible because it ignores things like string case
        (when the non-`raw_ops` name differs from the TF user-level API).
        """
        raw_op_name = op_def_lib.lower_op_name_to_raw.get(name.lower(), name)
        op_def = op_def_registry.get(raw_op_name)
        if op_def is not None:
            return TFlowMetaOpDef(obj=op_def)
Example #6
def _get_output_type(node: r.NodeDef, output_idx: int) -> int:
    """Return the type of the nth output of a node"""
    op = op_def_registry.get(node.op)
    output_arg = op.output_arg[output_idx]
    if output_arg.type != 0:
        return output_arg.type
    elif len(output_arg.type_attr) > 0:
        return node.attr[output_arg.type_attr].type
    else:
        raise ValueError(f'cannot determine output type of node "{node.name}"'
                         f' op={op.name}')
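A hedged illustration of the two resolution paths above (assumes `_get_output_type` and the TensorFlow protos are importable; the hand-built node is hypothetical).

from tensorflow.core.framework import node_def_pb2, types_pb2

# 'Cast' declares its output as `y: DstT`, so output_arg.type is unset and the
# type is resolved from the node's 'DstT' attribute (the elif branch above).
node = node_def_pb2.NodeDef(name='my_cast', op='Cast')
node.attr['DstT'].type = types_pb2.DT_INT32
print(_get_output_type(node, 0))  # DT_INT32 (enum value 3)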
Example #7
    def __init__(self):
        #
        # We need this in order to construct "Const" tensors directly, since
        # the "value" attr in a meta `NodeDef` is just a NumPy array and not
        # the `TensorProto` expected by `raw_ops.Const`.
        #
        def mt_const(value, dtype, name=None):
            return tf.raw_ops.Const(value=tensor_util.make_tensor_proto(value),
                                    dtype=dtype,
                                    name=name)

        opdef = op_def_registry.get("Const")
        self.opdef_signatures[opdef.name] = self.make_opdef_sig(
            opdef, mt_const)
Example #8
def get_op_def(op_name: Text) -> Optional[OpDef]:
    """
    Get the definition for a native TF operation.
    This is useful for checking whether an operation is supported or
    to get all valid inputs and attributes.

    Args:
        op_name: Name of the native TF operation (e.g. "AddV2")

    Returns:
        Protobuf object containing the operation definition, or `None` if the
        operation is not registered with TF
    """
    return op_def_registry.get(op_name)
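A minimal usage sketch (assuming TensorFlow is installed and `get_op_def` is importable; the op name is just an example).

op_def = get_op_def('AddV2')
if op_def is None:
    print('AddV2 is not registered with this TensorFlow build')
else:
    # The OpDef proto lists the declared inputs and attributes.
    print([arg.name for arg in op_def.input_arg])  # e.g. ['x', 'y']
    print([attr.name for attr in op_def.attr])     # e.g. ['T']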
Example #9
def register_ops_if_needed(graph_ops):
  """Register graph ops absent in op_def_registry, if present in c++ registry.

  Args:
    graph_ops: set with graph op names to register.

  Raises:
    tf.errors.NotFoundError: if `graph_ops` contains ops that are not in either
    python or c++ registry.
  """
  if all(op_def_registry.get(op) is not None for op in graph_ops):
    return

  # Note: Only raise missing op ValueError after trying to load ops.
  # This allows the test to exercise all the calls into TensorFlow
  # without having to write a combined C++ and Python test.
  op_def_registry.sync()
  missing_ops = {op for op in graph_ops if op_def_registry.get(op) is None}
  if missing_ops:
    raise tf.errors.NotFoundError(
        None, None,
        "Graph ops missing from the python registry (%s) are also absent from "
        "the c++ registry." % missing_ops)
Example #10
    def _strip_node_default_valued_attrs(node_def):
        """Removes default valued attributes from a single node def."""
        if node_def.op in op_name_to_function:
            return

        op_def = op_def_registry.get(node_def.op)
        if op_def is None:
            return

        attrs_to_strip = set()
        for attr_name, attr_value in node_def.attr.items():
            if _is_default_attr_value(op_def, attr_name, attr_value):
                attrs_to_strip.add(attr_name)

        for attr in attrs_to_strip:
            del node_def.attr[attr]
Example #11
def _GetOpDef(op_type_name, keywords):
  """Returns the OpDef, Graph and Producer. For use in _apply_op_helper."""
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    producer = g.graph_def_versions.producer
    # pylint: enable=protected-access
  except AssertionError as e:
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e.message}")

  return op_def, g, producer
Example #12
def fix_node_def(node_def, functions, shared_name_suffix, debug_name):
    """Replace functions calls and shared names in `node_def`."""
    if ("_gradient_op_type" in node_def.attr and node_def.op
            not in ["StatefulPartitionedCall", "PartitionedCall"]):
        logging.warning(
            "Importing a function (%s) with ops with custom gradients. Will likely "
            "fail if a gradient is requested.", debug_name)
    if node_def.op in functions:
        node_def.op = functions[node_def.op].name
    for _, attr_value in node_def.attr.items():
        if attr_value.WhichOneof("value") == "func":
            attr_value.func.name = functions[attr_value.func.name].name
        elif attr_value.WhichOneof("value") == "list":
            for fn in attr_value.list.func:
                fn.name = functions[fn.name].name

    # Fix old table creation bug.
    if node_def.op == "HashTableV2":
        if ("use_node_name_sharing" not in node_def.attr
                or not node_def.attr["use_node_name_sharing"].b):
            node_def.attr["use_node_name_sharing"].b = True
            # We are turning on node name sharing, so we have to make sure we don't
            # accidentally share a table resource.
            shared_name_suffix += "_{}".format(ops.uid())

    # TODO(b/124205571): Avoid accidental sharing and destruction of restored
    # resources. For now uniquify "shared_name" when loading functions to avoid
    # sharing.
    # TODO: Add regression test for b/150826922.
    op_def = op_def_registry.get(node_def.op)
    if op_def:
        attr = next((a for a in op_def.attr if a.name == "shared_name"), None)
        if attr:
            shared_name = None
            if "shared_name" in node_def.attr and node_def.attr[
                    "shared_name"].s:
                shared_name = node_def.attr["shared_name"].s
            elif attr.default_value.s:
                shared_name = compat.as_bytes(attr.default_value.s)
            if not shared_name:
                shared_name = compat.as_bytes(node_def.name)

            node_def.attr["shared_name"].s = (
                shared_name + compat.as_bytes(shared_name_suffix))
Example #13
    def __init__(self, op_def, node_def, inputs, outputs=None, obj=None):
        """Create a TensorFlow meta `Operation`.

        The real signature of `tf.Operation.__init__` includes the graph
        object, so we can't really use the signature directly.  This is part of the
        reason why we have `TFlowMetaOpFactory.__call__` and
        `TFlowMetaTensor.operator` + `TFlowMetaTensor.inputs` that do not
        directly use `__all_props__`/`TFlowMetaTensor.rands` and construct the
        objects directly.
        """
        super().__init__(obj=obj)

        if isinstance(op_def, str):
            op_def = op_def_registry.get(op_def)

        self.op_def = metatize(op_def)
        self.node_def = metatize(node_def)

        if isvar(inputs):
            self.inputs = inputs
        else:
            # Inputs are supposed to be immutable, so we're able to convert
            # lists to tuples.
            def _convert_inputs(arg, nested):
                if nested and isinstance(arg, list):
                    arg = tuple(metatize(i) for i in arg)
                else:
                    arg = metatize(arg)

                return arg

            if not isvar(self.op_def):
                self.inputs = tuple(
                    _convert_inputs(i, hasattr(info, "number_attr"))
                    for i, info in zip(inputs, self.op_def.obj.input_arg))
            else:
                self.inputs = tuple(_convert_inputs(i, False) for i in inputs)

        if outputs is not None:
            if isvar(outputs):
                self._outputs = outputs
            else:
                self._outputs = tuple(metatize(o) for o in outputs)
Example #14
    def _get_ref_args(self, node):
        """Determine whether an input of an op is ref-type.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of the arg names (as strs) that are ref-type.
    """
        op_def = op_def_registry.get(node.op)
        if op_def is None:
            return []

        ref_args = []
        for i, output_arg in enumerate(op_def.output_arg):
            if output_arg.is_ref:
                arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
                ref_args.append(arg_name)
        return ref_args
Example #15
  def lookup(self, f_name, func_def=None, optional=False):
    if f_name in self._op_defs.keys():
      return self._op_defs[f_name]

    if isinstance(func_def, types.FunctionType):
      if not hasattr(func_def, '_tfr_op_name'):
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      op_name = getattr(func_def, '_tfr_op_name')
    elif not func_def:
      op_name = f_name
    else:
      # TODO(fengliuai): create one utility method to match different apis.
      compose_dec = []
      for dec in func_def.decorator_list:
        if isinstance(dec, ast.Call):
          if isinstance(dec.func,
                        ast.Attribute) and dec.func.attr == 'Composite':
            compose_dec.append(dec)
          if isinstance(dec.func, ast.Name) and dec.func.id == 'Composite':
            compose_dec.append(dec)

      if not compose_dec:
        # skip a non-composition function
        if optional:
          return (None, None)
        else:
          raise KeyError('OpDef does not exist: ' + f_name)
      elif len(compose_dec) > 1:
        raise KeyError('Found more than one Composite op decorator.')
      else:
        op_name = compose_dec[0].args[0].value

    op_def = op_def_registry.get(op_name)
    if not op_def:
      raise ValueError('Not a registered op: ' + op_name)
    derived_attrs = _collect_derived_attrs_from_proto(op_def)
    self._op_defs[f_name] = (op_def, derived_attrs)
    return (op_def, derived_attrs)
Example #16
    def get_op_info(cls, opdef):
        """Return the TF Python API function signature for a given `OpDef`.

        Parameter
        ---------
           opdef: str or `OpDef` object (meta or base)
        """
        if isinstance(opdef, str):
            opdef_name = opdef
            opdef = op_def_registry.get(opdef_name)
        else:
            opdef_name = opdef.name

        opdef_sig = cls.opdef_signatures.get(opdef_name, None)

        if opdef_sig is None and opdef is not None:
            opdef_func = getattr(tf.raw_ops, opdef.name, None)
            opdef_sig = cls.make_opdef_sig(opdef, opdef_func)
            cls.opdef_signatures[opdef.name] = opdef_sig

        return opdef_sig
Example #17
def stripped_op_list_for_graph(graph_def):
    """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.

  Raises:
    ValueError: If an unregistered op is used.
  """
    # This is the Python equivalent of StrippedOpListForGraph in C++.
    # Unfortunately, since the Python op registry can differ from that in C++, we
    # can't remove the duplication using swig (at least naively).
    # TODO(irving): Support taking graphs directly.

    used_ops = ops_used_by_graph_def(graph_def)

    # These internal ops used by functions are not registered, so we need to
    # whitelist them.  # TODO(irving): Do something better here.
    op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
    op_defs = []
    for op in sorted(used_ops):
        op_def = op_def_registry.get(op)
        if op_def is not None:
            op_defs.append(op_def)
        elif op not in op_whitelist:
            raise ValueError(
                "Op %s is used by the graph, but is not registered" % op)

    return op_def_pb2.OpList(op=op_defs)
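A small sketch of how this is typically driven (hedged: assumes TensorFlow is importable and that `stripped_op_list_for_graph` and its helper `ops_used_by_graph_def` come from the same module as above; the graph is deliberately trivial).

import tensorflow as tf

with tf.Graph().as_default() as g:
    tf.constant(1.0) + tf.constant(2.0)  # adds Const and AddV2 nodes

stripped = stripped_op_list_for_graph(g.as_graph_def())
print(sorted(op.name for op in stripped.op))  # e.g. ['AddV2', 'Const']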
Example #18
def stripped_op_list_for_graph(graph_def):
    """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos.  The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph.
  """
    # This is similar to StrippedOpListForGraph in C++, but unlike its
    # C++ counterpart, this version does not require all ops to be registered.
    # This is done to support Prelu fusion in tfjs.
    used_ops = ops_used_by_graph_def(graph_def)
    op_defs = []
    for op in sorted(used_ops):
        op_def = op_def_registry.get(op)
        if op_def is not None:
            op_defs.append(op_def)
    return op_def_pb2.OpList(op=op_defs)
Example #19
  def __exit__(self, unused_type, unused_value, unused_traceback):
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Graph changed while trying to add control dependencies.")

    # pylint: disable=protected-access
    if hasattr(self._graph, "outer_graph"):
      outer_val = self._graph.outer_graph._add_control_dependencies
      self._graph._add_control_dependencies = outer_val
    else:
      self._graph._add_control_dependencies = False
    # pylint: enable=protected-access

    # map from resource tensor to the last op which used it
    last_op_using_resource_tensor = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_op_using_resource_tensor).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()
      # Ensure stateful ops run
      if op_def_registry.get(op.type) is None or op_is_stateful(op):
        ops_which_must_run.add(op)
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)  # pylint: disable=protected-access
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_op_using_resource_tensor:
              last_op_using_resource_tensor[input_id] = op
        ops_which_must_run = set([op])
        continue

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_op_using_resource_tensor.
      for inp in op.inputs:
        if inp.dtype != dtypes_module.resource:
          continue

        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we skip
        # to avoid the op getting a control dependency on itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_op_using_resource_tensor,
                               merge_for_resource)
        # Ensure uses of resources are serialized
        if input_id in last_op_using_resource_tensor:
          if (last_op_using_resource_tensor[input_id]._control_flow_context  # pylint: disable=protected-access
              is op._control_flow_context):  # pylint: disable=protected-access
            control_inputs.add(last_op_using_resource_tensor[input_id])
        # Ensure merges happen after the closing of a cond block
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)  # pylint: disable=protected-access
        last_op_using_resource_tensor[input_id] = op

      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):  # pylint: disable=protected-access
        if None in last_op_using_resource_tensor:
          op._add_control_input(last_op_using_resource_tensor[None])  # pylint: disable=protected-access
        last_op_using_resource_tensor[None] = op
      control_inputs = [c for c in control_inputs
                        if c._control_flow_context is op._control_flow_context]  # pylint: disable=protected-access
      op._add_control_inputs(control_inputs)  # pylint: disable=protected-access

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)
    for r in nest.flatten(list(self._returned_tensors), expand_composites=True):
      if self.ops_which_must_run:
        r.op._add_control_inputs(  # pylint: disable=protected-access
            [o for o in self.ops_which_must_run
             if o._control_flow_context is r.op._control_flow_context])  # pylint: disable=protected-access
Example #20
def convert_int64_to_int32(graph_def: r.GraphDef) -> r.GraphDef:
    """Convert int64 input to int32 for TFJS compatibility

    Args:
        graph_def: GraphDef proto containing the network layout
    Returns:
        Updated graph with int64 inputs converted to int32
    """
    inputs = util.get_input_nodes(graph_def)
    convert = [info.name for info in inputs if info.dtype == util.np.int64]
    if len(convert) == 0:
        return graph_def
    # quick access to nodes by name
    node_map = r.get_input_node_map(graph_def)
    # map of all node inputs to their referencing node and their argument index
    input_map = defaultdict(list)
    for node in graph_def.node:
        for index, name in enumerate(node.input):
            input_map[name].append((index, node))
    # type cast ops to add to the graph
    type_cast_ops = []
    # nodes that require a type cast op
    type_cast_candidates: Dict[str, Tuple[int, r.NodeDef]] = {}

    for node in map(lambda x: node_map[x], convert):
        _set_tensor_dtype(node, _DT_INT32)
        # find all nodes that reference this input and adjust their datatype
        # attributes if required
        # technical note: referenced_by is a stack, this really is a
        # depth-first recursion
        referenced_by = input_map[node.name]
        while len(referenced_by) > 0:
            idx, ref = referenced_by.pop()
            # get the input node and the index of the output tensor
            input_node, output_idx = _get_input_node(ref, idx, node_map)
            # find the description of this node's operation
            op = op_def_registry.get(ref.op)
            desc = op.input_arg[idx]
            # find out whether we can just change the input type and which
            # attributes we might need to touch
            if desc.type != 0 and desc.type != _DT_INT32:
                # input type is fixed and cannot be changed: add a type cast
                cast_op = _make_cast_node(input_node, output_idx, _DT_INT32,
                                          desc.type)
                ref.input[idx] = cast_op.name
                type_cast_ops.append(cast_op)
                node_map[cast_op.name] = cast_op
                input_map[cast_op.name].append((idx, ref))
            elif desc.type_list_attr != '' or desc.type_attr == '':
                # input arrays of potentially mixed types cannot be handled
                raise ValueError("don't know how to handle input type changes"
                                 f' for node "{ref.name}" op={ref.op}')
            else:
                # change the type of this input
                type_attr = desc.type_attr
                ref.attr[type_attr].type = _DT_INT32
                if ref.name in type_cast_candidates:
                    del type_cast_candidates[ref.name]
                # check the other inputs for type compatibility
                for i, desc in enumerate(op.input_arg):
                    if i == idx or desc.type_attr != type_attr:
                        continue  # not a matching input
                    input_node, output_idx = _get_input_node(ref, i, node_map)
                    if input_node.name in convert:
                        continue  # Placeholder that will be converted
                    src_type = _get_output_type(input_node, output_idx)
                    if src_type == _DT_INT32:
                        continue  # type matches already
                    if input_node.op == 'Const':
                        # weight tensor: harmonize_dtypes() will fix these
                        _set_tensor_dtype(input_node, _DT_INT32)
                    else:
                        # add node as a candidate for needing type cast op
                        type_cast_candidates[input_node.name] = (i, ref)
                # process any changed outputs next
                for idx, output in enumerate(op.output_arg):
                    if output.type_attr == type_attr:
                        input_name = _get_tensor_name(ref, idx)
                        referenced_by += input_map[input_name]

    for idx, ref in type_cast_candidates.values():
        # add type cast operations for all nodes that have a type mismatch
        inp_node, channel = _get_input_node(ref, idx, node_map)
        src_type = _get_output_type(inp_node, channel)
        if src_type != _DT_INT32:
            cast_op = _make_cast_node(inp_node, channel, src_type, _DT_INT32)
            ref.input[idx] = cast_op.name
            type_cast_ops.append(cast_op)
            node_map[cast_op.name] = cast_op

    graph_def.node.extend(type_cast_ops)
    return graph_def
Example #21
 def _op_def(self, op_name):
   return op_def_registry.get(op_name)
Example #22
def enable_jit_nonstateful(node_def):
    op_def = op_def_registry.get(node_def.op)
    if op_def is None:
        raise ValueError("Unregistered op being created: %s" % node_def)

    return not op_def.is_stateful
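A hedged sketch of the behavior (assumes the TensorFlow protos are importable and that both op names below are registered; the node defs are hand-built for illustration).

from tensorflow.core.framework import node_def_pb2

# Stateless ops are eligible for JIT; stateful ops are not.
print(enable_jit_nonstateful(node_def_pb2.NodeDef(name='mm', op='MatMul')))       # True
print(enable_jit_nonstateful(node_def_pb2.NodeDef(name='vh', op='VarHandleOp')))  # False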
Example #23
    def __exit__(self, unused_type, unused_value, unused_traceback):
        # pylint: disable=protected-access
        if context.executing_eagerly():
            return

        if self._graph is not ops.get_default_graph():
            raise RuntimeError(
                "Graph changed while trying to add control dependencies.")

        if hasattr(self._graph, "outer_graph"):
            outer_val = self._graph.outer_graph._add_control_dependencies
            self._graph._add_control_dependencies = outer_val
        else:
            self._graph._add_control_dependencies = False

        # map from resource tensor to the last op which wrote to it
        last_write_to_resource = {}
        # map from resource tensor to the list of reads from it since the last
        # write or since the beginning of the function.
        reads_since_last_write_to_resource = collections.defaultdict(list)
        # CollectiveManager manager_ids within a particular function call should not
        # be needed outside of that function call. So we keep them separate (though
        # the general idea of the maps is the same, in the future, we'll need to
        # correctly thread the control output outside).
        # Map from collective manager scope to the last op which used it
        collective_manager_scopes_opened = {}
        collective_manager_scopes_used = {}
        # set of conditional and loop exits
        ops_which_must_run = set()
        # merge which must depend on ops which use this resource
        merge_for_resource = {}

        new_operations = self._graph.get_operations()[self._n_operations:]

        # Ensures that uses of resource tensors get serialized properly and all
        # execute. This is done by keeping a map from resource tensor to the last op
        # in graph-construction order which used it (last_write_to_resource).
        #
        # Conditionals are written in TensorFlow such that every external tensor
        # accessed in the conditional goes through a switch op and every return
        # tensor (it's guaranteed that there will be at least one) goes through a
        # merge op.
        #
        # To handle conditionals, switches are handled in a special way (see
        # comments for _process_switch). Merge nodes created by TF's conditional
        # logic (as opposed to by _process_switch) are forced to run and also get a
        # control dependency added to them to ensure all stateful ops inside their
        # control flow context run.
        #
        # We also ensure that if an op is using a resource output by a switch node
        # (that is, a resource tensor for which there's a value in
        # merge_for_resource) this op will run before the merge for that resource.
        #
        # We try to add control inputs to nodes respecting their control flow
        # contexts to avoid dead nodes propagating everywhere and leading to
        # "retval[0] doesn't have value" errors. If a node gets a control dependency
        # on a dead node (i.e. a node from an untaken control flow branch) that node
        # will be marked as dead unless it's a merge node.
        #
        # TODO(apassos): serialize non-resource-taking stateful ops as well, and
        # test that it works. Support while loops. Support init_scope escaping from
        # this.
        for op in new_operations:
            # TODO(apassos) make this code safely support while loops.
            if control_flow_util.IsInWhileLoop(op):
                continue
            control_inputs = set()
            # Ensure stateful ops run
            if (op_def_registry.get(op.type) is None
                    or (op_is_stateful(op)
                        and op.type not in utils.RESOURCE_READ_OPS)):
                # TODO(srbs): Do not add functional ops to `ops_which_must_run` if
                # they only have variable reads and are otherwise stateless.
                ops_which_must_run.add(op)
            # Make a note of all opened manager_ids.
            if op.type == "NoOp":
                try:
                    collective_manager_scopes_opened[op.get_attr(
                        "_collective_manager_id")] = op
                except ValueError:
                    pass
            # Ignore switches (they're handled separately)
            if op.type == "Switch" and op.inputs[
                    0].dtype == dtypes_module.resource:
                continue
            # Make merges trigger all other computation which must run
            if op.type == "Merge":
                for o in ops_which_must_run:
                    op._add_control_input(o)
                    for inp in o.inputs:
                        input_id = ops.tensor_id(inp)
                        if input_id in last_write_to_resource:
                            last_write_to_resource[input_id] = op
                ops_which_must_run = set([op])
                continue

            resource_inputs = set()
            # Check for any resource inputs. If we find any, we update control_inputs
            # and last_write_to_resource.
            for inp, resource_type in _get_resource_inputs(op):
                is_read = resource_type == ResourceType.READ_ONLY
                input_id = ops.tensor_id(inp)

                # If the op receives the same resource tensor twice as an input, we skip
                # to avoid the op getting a control dependency on itself.
                if input_id in resource_inputs:
                    continue

                resource_inputs.add(input_id)
                # Deal with switches, finally.
                if inp.op.type == "Switch":
                    self._process_switch(inp.op, ops_which_must_run,
                                         last_write_to_resource,
                                         merge_for_resource)
                is_building_function = op.graph.building_function
                # Ensure uses of resources are serialized
                if input_id in last_write_to_resource:
                    if is_building_function or (
                            last_write_to_resource[input_id].
                            _control_flow_context is op._control_flow_context):
                        control_inputs.add(last_write_to_resource[input_id])
                # Ensure merges happen after the closing of a cond block
                if input_id in merge_for_resource:
                    merge_for_resource[input_id]._add_control_input(op)
                if is_read:
                    reads_since_last_write_to_resource[input_id].append(op)
                else:
                    control_inputs.update(
                        reads_since_last_write_to_resource[input_id])
                    reads_since_last_write_to_resource[input_id] = []
                    last_write_to_resource[input_id] = op

            if (op_is_stateful(op) and not resource_inputs
                    and op._control_flow_context is None):
                if None in last_write_to_resource:
                    op._add_control_input(last_write_to_resource[None])
                last_write_to_resource[None] = op

            # Ensure ordering of collective ops
            manager_ids = collective_manager_ids_from_op(op)
            for manager_id in manager_ids:
                if manager_id in collective_manager_scopes_opened:
                    # Chain this function call if the scope was opened.
                    op._add_control_input(
                        collective_manager_scopes_opened[manager_id])
                    collective_manager_scopes_opened[manager_id] = op
                else:
                    # If this op is in a scope not created here, create a chain starting
                    # at this op.
                    if manager_id in collective_manager_scopes_used:
                        op._add_control_input(
                            collective_manager_scopes_used[manager_id])
                    collective_manager_scopes_used[manager_id] = op

            if control_inputs and not is_building_function:
                control_inputs = [
                    c for c in control_inputs
                    if c._control_flow_context is op._control_flow_context
                ]

            op._add_control_inputs(control_inputs)

        # Ensure all ops which must run do run
        self.ops_which_must_run.update(ops_which_must_run)
        for r in nest.flatten(list(self._returned_tensors),
                              expand_composites=True):
            if self.ops_which_must_run:
                updated_ops_which_must_run = []
                if r.graph.building_function:
                    updated_ops_which_must_run = self.ops_which_must_run
                else:
                    updated_ops_which_must_run = [
                        o for o in self.ops_which_must_run if
                        o._control_flow_context is r.op._control_flow_context
                    ]
                r.op._add_control_inputs(updated_ops_which_must_run)

        self.collective_manager_ids_used = collective_manager_scopes_used
Example #24
  def __exit__(self, unused_type, unused_value, unused_traceback):
    # pylint: disable=protected-access
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Within the automatic control dependency context, the default graph"
          f" cannot change. Upon entry it was {self._graph}, but on exit it"
          f" changed to {ops.get_default_graph()}")

    if hasattr(self._graph, "outer_graph"):
      outer_val = self._graph.outer_graph._add_control_dependencies
      self._graph._add_control_dependencies = outer_val
    else:
      self._graph._add_control_dependencies = False

    # map from resource tensor to the last op which wrote to it
    last_write_to_resource = {}
    # map from resource tensor to the list of reads from it since the last
    # write or since the beginning of the function.
    reads_since_last_write_to_resource = collections.defaultdict(list)
    # CollectiveManager manager_ids within a particular function call should not
    # be needed outside of that function call. So we keep them separate (though
    # the general idea of the maps is the same, in the future, we'll need to
    # correctly thread the control output outside).
    # Map from collective manager scope to the last op which used it
    collective_manager_scopes_opened = {}
    collective_manager_scopes_used = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]
    first_use_for_res = {}
    resources_by_op = {}

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_write_to_resource).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a node from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()

      # Ensure stateful ops run.
      # Read-only ops are added to control outputs if the read value is
      # consumed. This covers the case when the read value is returned from
      # the function since that goes through a tf.identity in mark_as_return.
      if ((op_def_registry.get(op.type) is None) or
          (op_is_stateful(op) and
           (op.type not in utils.RESOURCE_READ_OPS or
            any(output.consumers() for output in op.outputs))) or
          (op.type in MUST_RUN_ORDER_INSENSITIVE_STATEFUL_OPS)):
        ops_which_must_run.add(op)

      # Make a note of all opened manager_ids.
      if op.type == "NoOp":
        try:
          collective_manager_scopes_opened[op.get_attr(
              "_collective_manager_id")] = op
        except ValueError:
          pass
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      # TODO(mdan): Don't do this. Write a transform to chains instead.
      # See core/common_runtime/control_flow_deps_to_chains.cc.
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_write_to_resource:
              last_write_to_resource[input_id] = op
        ops_which_must_run = set([op])
        continue

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_write_to_resource.
      for inp, resource_type in _get_resource_inputs(op):
        is_read = resource_type == ResourceType.READ_ONLY
        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we skip
        # to avoid the op getting a control dependency on itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_write_to_resource, merge_for_resource)
        is_building_function = op.graph.building_function
        # Ensure uses of resources are serialized
        if input_id in last_write_to_resource:
          if is_building_function or (
              last_write_to_resource[input_id]._control_flow_context
              is op._control_flow_context):
            control_inputs.add(last_write_to_resource[input_id])
        # Ensure merges happen after the closing of a cond block
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)

        do_record = (
            self.record_initial_resource_uses and
            input_id not in first_use_for_res)

        if is_read:
          reads_list = reads_since_last_write_to_resource[input_id]
          reads_list.append(op)

          if do_record:
            # Note: this will track the entire list that
            # reads_since_last_write_to_resource maintains. Updates to it will
            # and should be tracked, until the first write is encountered. At
            # that point, reads_since_last_write_to_resource will contain a new
            # empty list. This logic relies on that behavior.
            first_use_for_res[input_id] = reads_list

        else:
          control_inputs.update(reads_since_last_write_to_resource[input_id])
          reads_since_last_write_to_resource[input_id] = []
          last_write_to_resource[input_id] = op

          if do_record:
            first_use_for_res[input_id] = [op]

      if self.record_initial_resource_uses and op_is_stateful(op):
        if resource_inputs:
          resources_by_op[op] = tuple(resource_inputs)
        else:
          if None not in first_use_for_res:
            first_use_for_res[None] = [op]
          resources_by_op[op] = (None,)

      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):
        if None in last_write_to_resource:
          op._add_control_input(last_write_to_resource[None])
        last_write_to_resource[None] = op

      # Ensure ordering of collective ops
      manager_ids = collective_manager_ids_from_op(op)
      for manager_id in manager_ids:
        if manager_id in collective_manager_scopes_opened:
          # Chain this function call if the scope was opened.
          op._add_control_input(collective_manager_scopes_opened[manager_id])
          collective_manager_scopes_opened[manager_id] = op
        else:
          # If this op is in a scope not created here, create a chain starting
          # at this op.
          if manager_id in collective_manager_scopes_used:
            op._add_control_input(collective_manager_scopes_used[manager_id])
          collective_manager_scopes_used[manager_id] = op

      if control_inputs and not is_building_function:
        control_inputs = [
            c for c in control_inputs
            if c._control_flow_context is op._control_flow_context
        ]

      op._add_control_inputs(control_inputs)

    # Record the ops which first use resources touched by "ops which must run".
    if self.record_initial_resource_uses:
      first_uses_by_output_ops = {}
      for op in ops_which_must_run:
        if op not in resources_by_op:
          # This may happen with Merge/Switch nodes which are special cased
          # above.
          continue
        for r in resources_by_op[op]:
          if op not in first_uses_by_output_ops:
            first_uses_by_output_ops[op] = set()
          first_uses_by_output_ops[op].update(first_use_for_res[r])
      # For each "op which must run", set a private attr indicating the ops that
      # used the same resources it did.
      for op in first_uses_by_output_ops:
        others = [
            other.name.encode() for other in first_uses_by_output_ops[op]
        ]
        l = attr_value_pb2.AttrValue.ListValue(s=others)
        # TODO(mdan): Is there a way which doesn't use anonymous attrs?
        op._set_attr("_res_first_used_by", attr_value_pb2.AttrValue(list=l))

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)
    control_output_op = None
    for idx, r in enumerate(
        nest.flatten(list(self._returned_tensors), expand_composites=True)):
      if self.ops_which_must_run:
        updated_ops_which_must_run = []
        if r.graph.building_function:
          # There may be many stateful ops in the graph. Adding them as
          # control inputs to each function output could create excessive
          # control edges in the graph. Thus we create an intermediate No-op
          # to chain the control dependencies between stateful ops and
          # function outputs.
          if idx == 0:
            control_output_op = control_flow_ops.no_op()
            control_output_op._set_attr("_acd_function_control_output",
                                        attr_value_pb2.AttrValue(b=True))
            control_output_op._add_control_inputs(self.ops_which_must_run)
          updated_ops_which_must_run = [control_output_op]
        else:
          updated_ops_which_must_run = [
              o for o in self.ops_which_must_run
              if o._control_flow_context is r.op._control_flow_context
          ]
        r.op._add_control_inputs(updated_ops_which_must_run)

    self.collective_manager_ids_used = collective_manager_scopes_used
Example #25
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
    """Implementation of apply_op that returns output_structure, op."""
    op_def = op_def_registry.get(op_type_name)
    if op_def is None:
        raise RuntimeError("Unrecognized Op name " + op_type_name)

    # Determine the graph context.
    try:
        # Need to flatten all the arguments into a list.
        # pylint: disable=protected-access
        g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
        # pylint: enable=protected-access
    except AssertionError as e:
        raise RuntimeError("Cannot determine graph for Op '%s' due to: %s" %
                           (op_type_name, e.message))

    # Default name if not specified.
    if name is None:
        name = op_type_name

    # Check for deprecation
    deprecation_version = op_def.deprecation.version
    if deprecation_version:
        producer = g.graph_def_versions.producer
        if producer >= deprecation_version:
            raise NotImplementedError(
                ("Op %s is not available in GraphDef version %d. "
                 "It has been removed in version %d. %s.") %
                (op_type_name, producer, deprecation_version,
                 op_def.deprecation.explanation))

    # Fill in the list of default types for all "type" attrs.  This
    # will be used to choose a preferred dtype to convert to in the
    # absence of input type information.
    #
    # TODO(b/31302892): Currently the defaults don't work in the right
    # way if you have two inputs, one of whose type resolution depends
    # on the other.  Handling this will require restructuring this code
    # significantly.
    default_type_attr_map = {}
    for attr_def in op_def.attr:
        if attr_def.type != "type":
            continue
        key = attr_def.name
        if attr_def.HasField("default_value"):
            default_type_attr_map[key] = dtypes.as_dtype(
                attr_def.default_value.type)

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

        # Perform input type inference
        inferred_from = {}
        for input_arg in op_def.input_arg:
            input_name = input_arg.name
            if input_name in keywords:
                values = keywords.pop(input_name)
            elif input_name + "_" in keywords:
                # Handle the case where the name is a keyword or built-in
                # for Python so we use the name + _ instead.
                input_name += "_"
                values = keywords.pop(input_name)
            else:
                raise TypeError("No argument for input " + input_name)

            # Goals:
            # * Convert values to Tensors if it contains constants.
            # * Verify that values is a list if that matches the input_arg's
            #   type.
            # * If the input_arg's type is determined by attrs, either set
            #   those attrs and validate those attr values are legal (if
            #   they have not yet been set) or validate the input matches
            #   the type indicated by the attrs (if they have already been
            #   inferred via an earlier input).
            # * If the input_arg has an explicit type, make sure the input
            #   conforms.

            if _IsListParameter(input_arg):
                if not _IsListValue(values):
                    raise TypeError(
                        "Expected list for '%s' argument to '%s' Op, not %s." %
                        (input_name, op_type_name, values))
                # In cases where we expect all elements of the list to have the
                # same dtype, try to cast non-Tensor elements to that type.
                dtype = None
                default_dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.number_attr:
                    if input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                    else:
                        for t in values:
                            if isinstance(t, ops.Tensor):
                                dtype = t.dtype
                                break

                    # dtype still not found, prefer using the default dtype
                    # from the attr.
                    if dtype is None and input_arg.type_attr in default_type_attr_map:
                        default_dtype = default_type_attr_map[
                            input_arg.type_attr]

                try:
                    if not input_arg.is_ref and dtype:
                        dtype = dtypes.as_dtype(dtype).base_dtype
                    values = ops.internal_convert_n_to_tensor(
                        values,
                        name=input_arg.name,
                        dtype=dtype if dtype else None,
                        preferred_dtype=default_dtype,
                        as_ref=input_arg.is_ref)
                    if input_arg.number_attr and len(
                            set(v.dtype.base_dtype for v in values)) > 1:
                        raise TypeError()  # All types should match.
                except (TypeError, ValueError):
                    # What types does the conversion function think values have?
                    observed_types = []
                    for value in values:
                        try:
                            converted_value = ops.internal_convert_to_tensor(
                                value, as_ref=input_arg.is_ref)
                            observed_types.append(
                                converted_value.dtype.base_dtype.name)
                        except (TypeError, ValueError):
                            observed_types.append(
                                "<NOT CONVERTIBLE TO TENSOR>")
                    observed = ", ".join(observed_types)

                    prefix = (
                        "Tensors in list passed to '%s' of '%s' Op have types [%s]"
                        % (input_name, op_type_name, observed))
                    if input_arg.number_attr:
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s that do not match expected type %s." %
                                (prefix, dtype.name))
                        elif input_arg.type_attr in attrs:
                            raise TypeError(
                                "%s that do not match type %s inferred from "
                                "earlier arguments." % (prefix, dtype.name))
                        else:
                            raise TypeError("%s that don't all match." %
                                            prefix)
                    else:
                        raise TypeError("%s that are invalid. Tensors: %s" %
                                        (prefix, values))

                types = [x.dtype for x in values]
                inputs.extend(values)
            else:
                # In cases where we have an expected type, try to convert non-Tensor
                # arguments to that type.
                dtype = None
                default_dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.type_attr in attrs:
                    dtype = attrs[input_arg.type_attr]
                elif input_arg.type_attr in default_type_attr_map:
                    # The dtype could not be inferred solely from the inputs,
                    # so we prefer the attr's default; this keeps code that
                    # adds a new attr with a default backwards compatible.
                    default_dtype = default_type_attr_map[input_arg.type_attr]

                try:
                    values = ops.internal_convert_to_tensor(
                        values,
                        name=input_arg.name,
                        dtype=dtype,
                        as_ref=input_arg.is_ref,
                        preferred_dtype=default_dtype)
                except TypeError as err:
                    if dtype is None:
                        raise err
                    else:
                        raise TypeError(
                            "Expected %s passed to parameter '%s' of op '%s', got %s of "
                            "type '%s' instead. Error: %s" %
                            (dtypes.as_dtype(dtype).name,
                             input_arg.name, op_type_name, repr(values),
                             type(values).__name__, err))
                except ValueError:
                    # What type does convert_to_tensor think it has?
                    try:
                        observed = ops.internal_convert_to_tensor(
                            values, as_ref=input_arg.is_ref).dtype.name
                    except ValueError as err:
                        raise ValueError(
                            "Tried to convert '%s' to a tensor and failed. Error: %s"
                            % (input_name, err))
                    prefix = (
                        "Input '%s' of '%s' Op has type %s that does not match"
                        % (input_name, op_type_name, observed))
                    if input_arg.type != types_pb2.DT_INVALID:
                        raise TypeError(
                            "%s expected type of %s." %
                            (prefix, dtypes.as_dtype(input_arg.type).name))
                    else:
                        # Update the maps with the default, if needed.
                        k = input_arg.type_attr
                        if k in default_type_attr_map:
                            if k not in attrs:
                                attrs[k] = default_type_attr_map[k]
                                if k not in inferred_from:
                                    inferred_from[k] = "Default in OpDef"

                        raise TypeError(
                            "%s type %s of argument '%s'." %
                            (prefix, dtypes.as_dtype(
                                attrs[input_arg.type_attr]).name,
                             inferred_from[input_arg.type_attr]))

                types = [values.dtype]
                inputs.append(values)
            base_types = [x.base_dtype for x in types]

            if input_arg.number_attr:
                # <number-attr> * <type> or <number-attr> * <type-attr>
                if input_arg.number_attr in attrs:
                    if len(values) != attrs[input_arg.number_attr]:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d must match "
                            "length %d of argument '%s'." %
                            (input_name, op_type_name, len(values),
                             attrs[input_arg.number_attr],
                             inferred_from[input_arg.number_attr]))
                else:
                    attrs[input_arg.number_attr] = len(values)
                    inferred_from[input_arg.number_attr] = input_name
                    num_attr = _Attr(op_def, input_arg.number_attr)
                    if num_attr.has_minimum and len(values) < num_attr.minimum:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d shorter "
                            "than minimum length %d." %
                            (input_name, op_type_name, len(values),
                             num_attr.minimum))
                # All tensors must have the same base type.
                if any(bt != base_types[0] for bt in base_types):
                    raise TypeError("All tensors passed to '%s' of '%s' Op "
                                    "must have the same type." %
                                    (input_name, op_type_name))
                if input_arg.type != types_pb2.DT_INVALID:
                    # <number-attr> * <type> case
                    if base_types and base_types[0] != input_arg.type:
                        assert False, "Unreachable"
                elif input_arg.type_attr in attrs:
                    # <number-attr> * <type-attr> case, where <type-attr> already
                    # has an inferred value.
                    if base_types and base_types[0] != attrs[
                            input_arg.type_attr]:
                        assert False, "Unreachable"
                else:
                    # <number-attr> * <type-attr> case, where we are now setting
                    # the <type-attr> based on this input
                    if not base_types:
                        raise TypeError(
                            "Don't know how to infer type variable from empty input "
                            "list passed to input '%s' of '%s' Op." %
                            (input_name, op_type_name))
                    attrs[input_arg.type_attr] = base_types[0]
                    inferred_from[input_arg.type_attr] = input_name
                    type_attr = _Attr(op_def, input_arg.type_attr)
                    _SatisfiesTypeConstraint(base_types[0],
                                             type_attr,
                                             param_name=input_name)
            elif input_arg.type_attr:
                # <type-attr>
                attr_value = base_types[0]
                if input_arg.type_attr in attrs:
                    if attrs[input_arg.type_attr] != attr_value:
                        raise TypeError(
                            "Input '%s' of '%s' Op has type %s that does not "
                            "match type %s of argument '%s'." %
                            (input_name, op_type_name,
                             dtypes.as_dtype(attr_value).name,
                             dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                             inferred_from[input_arg.type_attr]))
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(base_type,
                                                 _Attr(op_def,
                                                       input_arg.type_attr),
                                                 param_name=input_name)
                    attrs[input_arg.type_attr] = attr_value
                    inferred_from[input_arg.type_attr] = input_name
            elif input_arg.type_list_attr:
                # <type-list-attr>
                attr_value = base_types
                if input_arg.type_list_attr in attrs:
                    if attrs[input_arg.type_list_attr] != attr_value:
                        raise TypeError(
                            "Input '%s' of '%s' Op has type list of %s that does not "
                            "match type list %s of argument '%s'." %
                            (input_name, op_type_name, ", ".join(
                                dtypes.as_dtype(x).name
                                for x in attr_value), ", ".join(
                                    dtypes.as_dtype(x).name
                                    for x in attrs[input_arg.type_list_attr]),
                             inferred_from[input_arg.type_list_attr]))
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(base_type,
                                                 _Attr(
                                                     op_def,
                                                     input_arg.type_list_attr),
                                                 param_name=input_name)
                    attrs[input_arg.type_list_attr] = attr_value
                    inferred_from[input_arg.type_list_attr] = input_name
            else:
                # single Tensor with specified type
                if base_types[0] != input_arg.type:
                    assert False, "Unreachable"

            if input_arg.is_ref:
                if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
                    raise TypeError((
                        "'%s' Op requires that input '%s' be a mutable tensor "
                        "(e.g.: a tf.Variable)") % (op_type_name, input_name))
                input_types.extend(types)
            else:
                input_types.extend(base_types)

        # Process remaining attrs
        for attr in op_def.attr:
            # Skip attrs that have already had their values inferred
            if attr.name in attrs:
                if attr.name in keywords:
                    raise TypeError(
                        "Should not specify value for inferred attr '%s'." %
                        attr.name)
                continue
            if attr.name in keywords:
                attrs[attr.name] = keywords.pop(attr.name)
            elif attr.name + "_" in keywords:
                # Attrs whose names match Python keywords have an extra '_'
                # appended, so we must check for that as well.
                attrs[attr.name] = keywords.pop(attr.name + "_")
            else:
                raise TypeError("No argument for attr " + attr.name)

        # Convert attr values to AttrValue protos.
        attr_protos = {}
        for attr_def in op_def.attr:
            key = attr_def.name
            value = attrs[key]
            attr_value = attr_value_pb2.AttrValue()
            if attr_def.HasField("default_value") and value is None:
                attr_value.CopyFrom(attr_def.default_value)
                attr_protos[key] = attr_value
                continue
            if attr_def.type.startswith("list("):
                if not _IsListValue(value):
                    raise TypeError("Expected list for attr " + key)
                if attr_def.has_minimum:
                    if len(value) < attr_def.minimum:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed list of length %d "
                            "less than minimum %d." %
                            (key, op_type_name, len(value), attr_def.minimum))
                attr_value.list.SetInParent()
            if attr_def.type == "string":
                attr_value.s = _MakeStr(value, key)
                if attr_def.HasField("allowed_values"):
                    if attr_value.s not in attr_def.allowed_values.list.s:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                            % (key, op_type_name, compat.as_text(
                                attr_value.s), '", "'.join(
                                    map(compat.as_text,
                                        attr_def.allowed_values.list.s))))
            elif attr_def.type == "list(string)":
                attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                if attr_def.HasField("allowed_values"):
                    for x in attr_value.list.s:
                        if x not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                % (key, op_type_name, compat.as_text(x),
                                   '", "'.join(
                                       map(compat.as_text,
                                           attr_def.allowed_values.list.s))))
            elif attr_def.type == "int":
                attr_value.i = _MakeInt(value, key)
                if attr_def.has_minimum:
                    if attr_value.i < attr_def.minimum:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed %d less than minimum %d."
                            % (key, op_type_name, attr_value.i,
                               attr_def.minimum))
            elif attr_def.type == "list(int)":
                attr_value.list.i.extend([_MakeInt(x, key) for x in value])
            elif attr_def.type == "float":
                attr_value.f = _MakeFloat(value, key)
            elif attr_def.type == "list(float)":
                attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
            elif attr_def.type == "bool":
                attr_value.b = _MakeBool(value, key)
            elif attr_def.type == "list(bool)":
                attr_value.list.b.extend([_MakeBool(x, key) for x in value])
            elif attr_def.type == "type":
                attr_value.type = _MakeType(value, attr_def)
            elif attr_def.type == "list(type)":
                attr_value.list.type.extend(
                    [_MakeType(x, attr_def) for x in value])
            elif attr_def.type == "shape":
                attr_value.shape.CopyFrom(_MakeShape(value, key))
            elif attr_def.type == "list(shape)":
                attr_value.list.shape.extend(
                    [_MakeShape(x, key) for x in value])
            elif attr_def.type == "tensor":
                attr_value.tensor.CopyFrom(_MakeTensor(value, key))
            elif attr_def.type == "list(tensor)":
                attr_value.list.tensor.extend(
                    [_MakeTensor(x, key) for x in value])
            elif attr_def.type == "func":
                attr_value.func.CopyFrom(_MakeFunc(value, key))
            elif attr_def.type == "list(func)":
                attr_value.list.func.extend([_MakeFunc(x, key) for x in value])
            else:
                raise TypeError("Unrecognized Attr type " + attr_def.type)

            attr_protos[key] = attr_value
        del attrs  # attrs is no longer authoritative, use attr_protos instead

        # Determine output types (possibly using attrs)
        output_structure = []
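        # Each entry describes one output arg: an int n for a list of n
        # tensors (from a number_attr), the length of the inferred dtype list
        # for a type_list_attr output, or None for a single tensor.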
        for arg in op_def.output_arg:
            if arg.number_attr:
                n = _AttrValue(attr_protos, arg.number_attr).i
                output_structure.append(n)
            elif arg.type_attr:
                t = _AttrValue(attr_protos, arg.type_attr)
                output_structure.append(None)
            elif arg.type_list_attr:
                t = _AttrValue(attr_protos, arg.type_list_attr)
                output_structure.append(len(t.list.type))
            else:
                output_structure.append(None)

        if keywords:
            raise TypeError("apply_op() got unexpected keyword arguments: " +
                            ", ".join(sorted(keywords.keys())))

        # NOTE(mrry): We add an explicit colocation constraint between
        # the newly created op and any of its reference-typed inputs.
        must_colocate_inputs = [
            val for arg, val in zip(op_def.input_arg, inputs) if arg.is_ref
        ]
        with _MaybeColocateWith(must_colocate_inputs):
            # Add Op to graph
            # pylint: disable=protected-access
            op = g._create_op_internal(op_type_name,
                                       inputs,
                                       dtypes=None,
                                       name=scope,
                                       input_types=input_types,
                                       attrs=attr_protos,
                                       op_def=op_def)

        # `outputs` is returned as a separate return value so that the output
        # tensors and the `op` per se can be decoupled, which lets the
        # `op_callbacks` function properly. See framework/op_callbacks.py
        # for more details.
        outputs = op.outputs
        # Conditionally invoke tfdbg v2's op callback(s).
        if op_callbacks.should_invoke_op_callbacks():
            callback_outputs = op_callbacks.invoke_op_callbacks(
                op.node_def.op,
                tuple(op.inputs),
                attr_protos,
                tuple(outputs),
                op_name=op.name,
                graph=g)
            if callback_outputs is not None:
                outputs = callback_outputs

        return output_structure, op_def.is_stateful, op, outputs
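
The manual attr-conversion branch above fills AttrValue protos field by field. As a rough, hedged sketch (the attr names and values here are invented for illustration and not taken from any particular OpDef), this is the kind of proto that branch produces for a "string" attr and a "list(int)" attr:

from tensorflow.core.framework import attr_value_pb2

# attr type "string": stored as bytes in the `s` field.
padding_attr = attr_value_pb2.AttrValue(s=b"SAME")

# attr type "list(int)": stored in the repeated `list.i` field.
strides_attr = attr_value_pb2.AttrValue()
strides_attr.list.i.extend([1, 2, 2, 1])
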
Example No. 26
0
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op that returns output_structure, op."""
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e.message}")

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          f"Op {op_type_name} is not available in GraphDef version {producer}. "
          f"It has been removed in version {deprecation_version}. "
          f"{op_def.deprecation.explanation}.")

  # Fill in the list of default types for all "type" attrs.  This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  allowed_list_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[key] = attr_def.allowed_values.list.type

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a Python keyword or built-in:
        # the argument is then passed as the name with a trailing "_".
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError(f"No argument for input {input_name} found in {op_def}")

      # Goals:
      # * Convert values to Tensors if it contains constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              f"Expected list for '{input_name}' argument to '{op_type_name}' "
              f"Op, not {values}.")
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.internal_convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          all_types = set(v.dtype.base_dtype for v in values)
          if input_arg.number_attr and len(all_types) > 1:
            # All types should match.
            raise TypeError(f"Not all types matched for {input_arg.name} for "
                            f"{op_type_name}. Got {all_types}")
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError(f"{prefix} that do not match expected type "
                              f"{dtype.name}.")
            elif input_arg.type_attr in attrs:
              raise TypeError(f"{prefix} that do not match type {dtype.name} "
                              "inferred from earlier arguments.")
            else:
              raise TypeError(f"{prefix} that don't all match.")
          else:
            raise TypeError(f"{prefix} that are invalid. Tensors: {values}")

        types = [x.dtype for x in values]
        inputs.extend(values)
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        allowed_list = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default; this keeps code that adds a
          # new attr with a default backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]
          allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

        try:
          # First see if we can get a valid dtype with the default conversion
          # and check whether it matches one of the allowed dtypes. Some ops,
          # like ConcatV2, may not list allowed dtypes, in which case we skip
          # this step.
          if dtype is None and allowed_list:
            inferred = None
            try:
              inferred = ops.convert_to_tensor(
                  values, name=input_arg.name, as_ref=input_arg.is_ref)
            except TypeError:
              # When converting a Python object such as a list of Dimensions,
              # a dtype may need to be specified, so the tensor conversion can
              # raise an exception; ignore it here and try again below.
              pass

            # If we did not match an allowed dtype, try again with the default
            # dtype. This could be because we have an empty tensor and thus we
            # picked the wrong type.
            if inferred is not None and inferred.dtype in allowed_list:
              values = inferred
            else:
              values = ops.convert_to_tensor(
                  values,
                  name=input_arg.name,
                  as_ref=input_arg.is_ref,
                  preferred_dtype=default_dtype)
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        except TypeError as err:
          if dtype is None:
            raise err
          else:
            raise TypeError(
                f"Expected {dtypes.as_dtype(dtype).name} passed to parameter "
                f"'{input_arg.name}' of op '{op_type_name}', got "
                f"{repr(values)} of type '{type(values).__name__}' instead. "
                f"Error: {err}")
        except ValueError:
          # What type does convert_to_tensor think it has?
          try:
            observed = ops.convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
          except ValueError as err:
            raise ValueError(
                f"Tried to convert '{input_name}' to a tensor and failed. "
                f"Error: {err}")
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError(f"{prefix} expected type of "
                            f"{dtypes.as_dtype(input_arg.type).name}.")
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                f"{prefix} type "
                f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
                f"argument '{inferred_from[input_arg.type_attr]}'.")

        types = [values.dtype]
        inputs.append(values)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                f"List argument '{input_name}' to '{op_type_name}' Op with "
                f"length {len(values)} must match length "
                f"{attrs[input_arg.number_attr]} of argument "
                f"'{inferred_from[input_arg.number_attr]}'.")
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                f"List argument '{input_name}' to '{op_type_name}' Op with "
                f"length {len(values)} shorter than minimum length "
                f"{num_attr.minimum}.")
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              f"All tensors passed to '{input_name}' of '{op_type_name}' Op "
              f"must have the same type. Got {base_types} instead.")
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            # If it's in default_type_attr_map, then wait to set it
            # (in "process remaining attrs", below).
            if input_arg.type_attr not in default_type_attr_map:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  f"list passed to input '{input_name}' of '{op_type_name}' "
                  "Op.")
          else:
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
                                     param_name=input_name)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            raise TypeError(
                f"Input '{input_name}' of '{op_type_name}' Op has type "
                f"{dtypes.as_dtype(attr_value).name} that does not match type "
                f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
                f"argument '{inferred_from[input_arg.type_attr]}'.")
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr),
                                     param_name=input_name)
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            actual_types = ", ".join(
                dtypes.as_dtype(x).name for x in attr_value)
            expected_types = ", ".join(
                dtypes.as_dtype(x).name
                for x in attrs[input_arg.type_list_attr])
            raise TypeError(
                f"Input '{input_name}' of '{op_type_name}' Op has type list of "
                f"{actual_types} that does not match type list {expected_types}"
                f" of argument '{inferred_from[input_arg.type_list_attr]}'.")
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr),
                                     param_name=input_name)
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      if input_arg.is_ref:
        if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
          raise TypeError(
              f"'{op_type_name}' Op requires that input '{input_name}' be a "
              "mutable tensor (e.g.: a tf.Variable)")
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              f"Should not specify value for inferred attr '{attr.name}' for "
              f"{op_type_name}.")
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      elif attr.name in default_type_attr_map:
        attrs[attr.name] = default_type_attr_map[attr.name]
        inferred_from.setdefault(attr.name, "Default in OpDef")
      else:
        raise TypeError(f"No argument found for attr {attr.name} for "
                        f"{op_type_name}")

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]

      if attr_def.HasField("default_value") and value is None:
        attr_value = attr_value_pb2.AttrValue()
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue

      attr_value = value_to_attr_value(value, attr_def.type, key)
      if attr_def.type.startswith("list("):
        _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
      if attr_def.HasField("allowed_values"):
        if attr_def.type == "string":
          _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
                                             op_type_name)
        elif attr_def.type == "list(string)":
          for value in attr_value.list.s:
            _SatisfiesAllowedStringsConstraint(value, attr_def, key,
                                               op_type_name)
      if attr_def.has_minimum and attr_def.type == "int":
        _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
                                       op_type_name)
      if attr_def.type == "type":
        _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
      if attr_def.type == "list(type)":
        for value in attr_value.list.type:
          _SatisfiesTypeConstraint(value, attr_def, key)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_structure = []
    for arg in op_def.output_arg:
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr, op_type_name).i
        output_structure.append(n)
      elif arg.type_attr:
        t = _AttrValue(attr_protos, arg.type_attr, op_type_name)
        output_structure.append(None)
      elif arg.type_list_attr:
        t = _AttrValue(attr_protos, arg.type_list_attr, op_type_name)
        output_structure.append(len(t.list.type))
      else:
        output_structure.append(None)

    if keywords:
      all_keywords = ", ".join(sorted(keywords.keys()))
      raise TypeError(f"{op_type_name} got unexpected keyword arguments: "
                      f"{all_keywords}.")

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` per se can be decoupled, which lets the
    # `op_callbacks` function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs
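
As a minimal usage sketch (assuming a TensorFlow 2.x build where _apply_op_helper is a module-level function in tensorflow.python.framework.op_def_library; the op name and constants are illustrative), generated op wrappers call this helper in graph mode roughly like so:

import tensorflow as tf
from tensorflow.python.framework import op_def_library

# Build an explicit graph so the helper has a graph context to add the op to.
with tf.Graph().as_default():
  x = tf.constant([1.0, 2.0], name="x")
  # "Identity" has a single input and a type attr T inferred from `x`.
  output_structure, is_stateful, op, outputs = op_def_library._apply_op_helper(
      "Identity", input=x, name="identity_demo")
  print(op.type, is_stateful, output_structure)  # Identity False [None]
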
Example No. 27
0
def _get_op_def(op):
    return op.op_def or op_def_registry.get(op.type)
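
A brief, hedged illustration of the fallback (the graph and op below are assumed for the example and are not part of the source): an op that carries its own op_def, such as one created from an imported function, is served directly, while other ops are resolved through the registry keyed by op type.

import tensorflow as tf

with tf.Graph().as_default():
    c = tf.constant(1.0)
    # Returns the OpDef for the "Const" op, from the op itself or the registry.
    op_def = _get_op_def(c.op)
    print(op_def.name, [arg.name for arg in op_def.output_arg])  # Const ['output']
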
Example No. 28
0
def is_registered_stateful_op_without_inputs(name):
  """Checks if an op is registered, stateful and does not expect inputs."""
  op_def = op_def_registry.get(name)
  return op_def is not None and (op_def.is_stateful and not op_def.input_arg)
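
For instance (hedged; the exact set of ops that qualify can vary across TensorFlow versions), a handle-creating op such as "VarHandleOp" is typically registered as stateful with no input args, whereas "Add" is stateless and takes inputs:

print(is_registered_stateful_op_without_inputs("VarHandleOp"))  # True (typically)
print(is_registered_stateful_op_without_inputs("Add"))           # False
print(is_registered_stateful_op_without_inputs("NotARealOp"))    # False: not registered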