Exemplo n.º 1
0
    def _apply_op_helper(self, op_type_name, name=None, **keywords):
        """Implementation of `apply_op`: adds a node for a registered Op.

        Looks up the registered OpDef for `op_type_name`. It converts the
        keyword arguments that name the Op's inputs into Tensors, inferring
        "type"/"number"/"list(type)" attrs from them where possible. It then
        converts the remaining attr keyword arguments into AttrValue protos and
        calls `g.create_op` on the graph determined from the inputs.

        Args:
          op_type_name: string. Name of a registered Op (a key of `self._ops`).
          name: string. Optional name for the created node; defaults to
            `op_type_name`. It is opened as a name scope and the resulting
            scope string is what gets passed to `create_op`.
          **keywords: Input values and attr values, keyed by the names used in
            the OpDef. A name with a trailing "_" is also accepted (for names
            that clash with Python keywords/builtins). Attrs that have a
            default in the OpDef may be passed as None to get that default.
            Attrs inferred from inputs must not be passed.

        Returns:
          A tuple `(output_structure, is_stateful, op)`:
            output_structure: list with one entry per OpDef output arg --
              `None` for a single-Tensor output, otherwise the number of
              Tensors in that list-typed output.
            is_stateful: `op_def.is_stateful`.
            op: the created Operation.

        Raises:
          RuntimeError: If the Op is not registered, or the graph cannot be
            determined from the inputs.
          NotImplementedError: If the Op has been removed in the graph's
            GraphDef producer version.
          TypeError: On missing inputs/attrs, type mismatches, or unexpected
            keyword arguments.
          ValueError: On list-length or attr-value constraint violations, or
            when an input cannot be converted to a Tensor.
        """
        op_info = self._ops.get(op_type_name, None)
        if op_info is None:
            raise RuntimeError("Unrecognized Op name " + op_type_name)
        op_def = op_info.op_def

        # Determine the graph context.
        try:
            # Need to flatten all the arguments into a list.
            # pylint: disable=protected-access
            g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
            # pylint: enable=protected-access
        except AssertionError as e:
            # Exceptions have no `.message` attribute on Python 3 (accessing
            # it would turn this RuntimeError into an AttributeError);
            # `str(e)` works on both Python 2 and 3.
            raise RuntimeError(
                "Cannot determine graph for Op '%s' due to: %s" %
                (op_type_name, str(e)))

        # Default name if not specified.
        if name is None:
            name = op_type_name

        # Check for deprecation: refuse to build ops removed at or before the
        # graph's producer version.
        deprecation_version = op_def.deprecation.version
        if deprecation_version:
            producer = g.graph_def_versions.producer
            if producer >= deprecation_version:
                raise NotImplementedError(
                    ("Op %s is not available in GraphDef version %d. "
                     "It has been removed in version %d. %s.") %
                    (op_type_name, producer, deprecation_version,
                     op_def.deprecation.explanation))

        # Fill in the list of default types for all "type" attrs.  This
        # will be used to choose a preferred dtype to convert to in the
        # absence of input type information.
        #
        # TODO(b/31302892): Currently the defaults don't work in the right
        # way if you have two inputs, one of whose type resolution depends
        # on the other.  Handling this will require restructuring this code
        # significantly.
        default_type_attr_map = {}
        for attr_def in op_def.attr:
            if attr_def.type != "type":
                continue
            key = attr_def.name
            if attr_def.HasField("default_value"):
                default_type_attr_map[key] = dtypes.as_dtype(
                    attr_def.default_value.type)

        # Requires that op_def has passed validation (using the C++
        # ValidateOpDef() from ../framework/op_def_util.h).
        attrs = {}  # attr name -> value (inferred from inputs or passed in)
        inputs = []  # flat list of input Tensors, in OpDef input order
        input_types = []  # one dtype per entry of `inputs`
        # Tensors belonging to ref-typed input args, collected per argument so
        # they stay correctly associated even when list-typed inputs (which
        # contribute several entries to the flat `inputs` list) come first.
        must_colocate_inputs = []
        with g.as_default(), ops.name_scope(name) as scope:

            # Perform input type inference.
            # attr name -> input arg name (or "Default in OpDef") that
            # determined it; only used to build error messages.
            inferred_from = {}
            for input_arg in op_def.input_arg:
                input_name = input_arg.name
                if input_name in keywords:
                    values = keywords.pop(input_name)
                elif input_name + "_" in keywords:
                    # Handle the case where the name is a keyword or built-in
                    # for Python so we use the name + _ instead.
                    input_name += "_"
                    values = keywords.pop(input_name)
                else:
                    raise TypeError("No argument for input " + input_name)

                # Goals:
                # * Convert values to Tensors if it contains constants.
                # * Verify that values is a list if that matches the input_arg's
                #   type.
                # * If the input_arg's type is determined by attrs, either set
                #   those attrs and validate those attr values are legal (if
                #   they have not yet been set) or validate the input matches
                #   the type indicated by the attrs (if they have already been
                #   inferred via an earlier input).
                # * If the input_arg has an explicit type, make sure the input
                #   conforms.

                if _IsListParameter(input_arg):
                    if not _IsListValue(values):
                        raise TypeError(
                            "Expected list for '%s' argument to '%s' Op, not %s."
                            % (input_name, op_type_name, values))
                    # In cases where we expect all elements of the list to have the
                    # same dtype, try to cast non-Tensor elements to that type.
                    dtype = None
                    default_dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.number_attr:
                        if input_arg.type_attr in attrs:
                            dtype = attrs[input_arg.type_attr]
                        else:
                            # Use the dtype of the first element that is
                            # already a Tensor.
                            for t in values:
                                if isinstance(t, ops.Tensor):
                                    dtype = t.dtype
                                    break

                        # dtype still not found, prefer using the default dtype
                        # from the attr.
                        if (dtype is None and
                                input_arg.type_attr in default_type_attr_map):
                            default_dtype = default_type_attr_map[
                                input_arg.type_attr]

                    try:
                        if not input_arg.is_ref and dtype:
                            dtype = dtypes.as_dtype(dtype).base_dtype
                        values = ops.internal_convert_n_to_tensor(
                            values,
                            name=input_arg.name,
                            dtype=dtype if dtype else None,
                            preferred_dtype=default_dtype,
                            as_ref=input_arg.is_ref)
                        if input_arg.number_attr and len(
                                set(v.dtype.base_dtype for v in values)) > 1:
                            raise TypeError()  # All types should match.
                    except (TypeError, ValueError):
                        # What types does the conversion function think values have?
                        observed_types = []
                        for value in values:
                            try:
                                converted_value = ops.internal_convert_to_tensor(
                                    value, as_ref=input_arg.is_ref)
                                observed_types.append(
                                    converted_value.dtype.base_dtype.name)
                            except (TypeError, ValueError):
                                observed_types.append(
                                    "<NOT CONVERTIBLE TO TENSOR>")
                        observed = ", ".join(observed_types)

                        prefix = (
                            "Tensors in list passed to '%s' of '%s' Op have types [%s]"
                            % (input_name, op_type_name, observed))
                        if input_arg.number_attr:
                            # `dtype` is only normalized to a DType above for
                            # non-ref inputs; for ref inputs with a fixed type
                            # it is still the raw OpDef enum value, which has
                            # no `.name`. Normalize before formatting.
                            if input_arg.type != types_pb2.DT_INVALID:
                                raise TypeError(
                                    "%s that do not match expected type %s." %
                                    (prefix, dtypes.as_dtype(dtype).name))
                            elif input_arg.type_attr in attrs:
                                raise TypeError(
                                    "%s that do not match type %s inferred from "
                                    "earlier arguments." %
                                    (prefix, dtypes.as_dtype(dtype).name))
                            else:
                                raise TypeError("%s that don't all match." %
                                                prefix)
                        else:
                            raise TypeError(
                                "%s that are invalid. Tensors: %s" %
                                (prefix, values))

                    types = [x.dtype for x in values]
                    inputs.extend(values)
                    arg_tensors = list(values)
                else:
                    # In cases where we have an expected type, try to convert non-Tensor
                    # arguments to that type.
                    dtype = None
                    default_dtype = None
                    if input_arg.type != types_pb2.DT_INVALID:
                        dtype = input_arg.type
                    elif input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                    elif input_arg.type_attr in default_type_attr_map:
                        # The dtype could not be inferred solely from the inputs,
                        # so we prefer the attr's default, so code that adds a new attr
                        # with a default is backwards compatible.
                        default_dtype = default_type_attr_map[
                            input_arg.type_attr]

                    try:
                        values = ops.internal_convert_to_tensor(
                            values,
                            name=input_arg.name,
                            dtype=dtype,
                            as_ref=input_arg.is_ref,
                            preferred_dtype=default_dtype)
                    except TypeError as err:
                        if dtype is None:
                            raise err
                        else:
                            raise TypeError(
                                "Expected %s passed to parameter '%s' of op '%s', got %s of "
                                "type '%s' instead." %
                                (dtypes.as_dtype(dtype).name,
                                 input_arg.name, op_type_name, repr(values),
                                 type(values).__name__))
                    except ValueError:
                        # What type does convert_to_tensor think it has?
                        try:
                            observed = ops.internal_convert_to_tensor(
                                values, as_ref=input_arg.is_ref).dtype.name
                        except ValueError as err:
                            raise ValueError(
                                "Tried to convert '%s' to a tensor and failed. Error: %s"
                                % (input_name, err))
                        prefix = (
                            "Input '%s' of '%s' Op has type %s that does not match"
                            % (input_name, op_type_name, observed))
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s expected type of %s." %
                                (prefix, dtypes.as_dtype(input_arg.type).name))
                        else:
                            # Update the maps with the default, if needed.
                            k = input_arg.type_attr
                            if k in default_type_attr_map:
                                if k not in attrs:
                                    attrs[k] = default_type_attr_map[k]
                                    if k not in inferred_from:
                                        inferred_from[k] = "Default in OpDef"

                            raise TypeError(
                                "%s type %s of argument '%s'." %
                                (prefix,
                                 dtypes.as_dtype(
                                     attrs[input_arg.type_attr]).name,
                                 inferred_from[input_arg.type_attr]))

                    types = [values.dtype]
                    inputs.append(values)
                    arg_tensors = [values]
                base_types = [x.base_dtype for x in types]

                if input_arg.number_attr:
                    # <number-attr> * <type> or <number-attr> * <type-attr>
                    if input_arg.number_attr in attrs:
                        if len(values) != attrs[input_arg.number_attr]:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d must match "
                                "length %d of argument '%s'." %
                                (input_name, op_type_name, len(values),
                                 attrs[input_arg.number_attr],
                                 inferred_from[input_arg.number_attr]))
                    else:
                        attrs[input_arg.number_attr] = len(values)
                        inferred_from[input_arg.number_attr] = input_name
                        num_attr = _Attr(op_def, input_arg.number_attr)
                        if num_attr.has_minimum and len(
                                values) < num_attr.minimum:
                            raise ValueError(
                                "List argument '%s' to '%s' Op with length %d shorter "
                                "than minimum length %d." %
                                (input_name, op_type_name, len(values),
                                 num_attr.minimum))
                    # All tensors must have the same base type.
                    if any(bt != base_types[0] for bt in base_types):
                        raise TypeError(
                            "All tensors passed to '%s' of '%s' Op "
                            "must have the same type." %
                            (input_name, op_type_name))
                    if input_arg.type != types_pb2.DT_INVALID:
                        # <number-attr> * <type> case
                        if base_types and base_types[0] != input_arg.type:
                            assert False, "Unreachable"
                    elif input_arg.type_attr in attrs:
                        # <number-attr> * <type-attr> case, where <type-attr> already
                        # has an inferred value.
                        if base_types and base_types[0] != attrs[
                                input_arg.type_attr]:
                            assert False, "Unreachable"
                    else:
                        # <number-attr> * <type-attr> case, where we are now setting
                        # the <type-attr> based on this input
                        if not base_types:
                            raise TypeError(
                                "Don't know how to infer type variable from empty input "
                                "list passed to input '%s' of '%s' Op." %
                                (input_name, op_type_name))
                        attrs[input_arg.type_attr] = base_types[0]
                        inferred_from[input_arg.type_attr] = input_name
                        type_attr = _Attr(op_def, input_arg.type_attr)
                        _SatisfiesTypeConstraint(base_types[0],
                                                 type_attr,
                                                 param_name=input_name)
                elif input_arg.type_attr:
                    # <type-attr>
                    attr_value = base_types[0]
                    if input_arg.type_attr in attrs:
                        if attrs[input_arg.type_attr] != attr_value:
                            assert False, "Unreachable"
                    else:
                        for base_type in base_types:
                            _SatisfiesTypeConstraint(base_type,
                                                     _Attr(
                                                         op_def,
                                                         input_arg.type_attr),
                                                     param_name=input_name)
                        attrs[input_arg.type_attr] = attr_value
                        inferred_from[input_arg.type_attr] = input_name
                elif input_arg.type_list_attr:
                    # <type-list-attr>
                    attr_value = base_types
                    if input_arg.type_list_attr in attrs:
                        if attrs[input_arg.type_list_attr] != attr_value:
                            raise TypeError(
                                "Input '%s' of '%s' Op has type list of %s that does not "
                                "match type list %s of argument '%s'." %
                                (input_name, op_type_name, ", ".join(
                                    dtypes.as_dtype(x).name
                                    for x in attr_value),
                                 ", ".join(
                                     dtypes.as_dtype(x).name
                                     for x in attrs[input_arg.type_list_attr]),
                                 inferred_from[input_arg.type_list_attr]))
                    else:
                        for base_type in base_types:
                            _SatisfiesTypeConstraint(
                                base_type,
                                _Attr(op_def, input_arg.type_list_attr),
                                param_name=input_name)
                        attrs[input_arg.type_list_attr] = attr_value
                        inferred_from[input_arg.type_list_attr] = input_name
                else:
                    # single Tensor with specified type
                    if base_types[0] != input_arg.type:
                        assert False, "Unreachable"

                if input_arg.is_ref:
                    if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
                        raise TypeError((
                            "'%s' Op requires that input '%s' be a mutable tensor "
                            "(e.g.: a tf.Variable)") %
                                        (op_type_name, input_name))
                    # Ref inputs keep their ref dtypes and are remembered for
                    # colocation with the new op.
                    input_types.extend(types)
                    must_colocate_inputs.extend(arg_tensors)
                else:
                    input_types.extend(base_types)

            # Process remaining attrs
            for attr in op_def.attr:
                # Skip attrs that have already had their values inferred
                if attr.name in attrs:
                    if attr.name in keywords:
                        raise TypeError(
                            "Should not specify value for inferred attr '%s'."
                            % attr.name)
                    continue
                if attr.name in keywords:
                    attrs[attr.name] = keywords.pop(attr.name)
                elif attr.name + "_" in keywords:
                    # Attrs whose names match Python keywords have an extra '_'
                    # appended, so we must check for that as well.
                    attrs[attr.name] = keywords.pop(attr.name + "_")
                else:
                    raise TypeError("No argument for attr " + attr.name)

            # Convert attr values to AttrValue protos.
            attr_protos = {}
            for attr_def in op_def.attr:
                key = attr_def.name
                value = attrs[key]
                attr_value = attr_value_pb2.AttrValue()
                # An explicit None selects the OpDef default, if there is one.
                if attr_def.HasField("default_value") and value is None:
                    attr_value.CopyFrom(attr_def.default_value)
                    attr_protos[key] = attr_value
                    continue
                if attr_def.type.startswith("list("):
                    if not _IsListValue(value):
                        raise TypeError("Expected list for attr " + key)
                    if attr_def.has_minimum:
                        if len(value) < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed list of length %d "
                                "less than minimum %d." %
                                (key, op_type_name, len(value),
                                 attr_def.minimum))
                    # Mark the (possibly empty) list field as set.
                    attr_value.list.SetInParent()
                if attr_def.type == "string":
                    attr_value.s = _MakeStr(value, key)
                    if attr_def.HasField("allowed_values"):
                        if attr_value.s not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                % (key, op_type_name,
                                   compat.as_text(attr_value.s), '", "'.join(
                                       map(compat.as_text,
                                           attr_def.allowed_values.list.s))))
                elif attr_def.type == "list(string)":
                    attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                    if attr_def.HasField("allowed_values"):
                        for x in attr_value.list.s:
                            if x not in attr_def.allowed_values.list.s:
                                raise ValueError(
                                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                    %
                                    (key, op_type_name, compat.as_text(x),
                                     '", "'.join(
                                         map(compat.as_text,
                                             attr_def.allowed_values.list.s))))
                elif attr_def.type == "int":
                    attr_value.i = _MakeInt(value, key)
                    if attr_def.has_minimum:
                        if attr_value.i < attr_def.minimum:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed %d less than minimum %d."
                                % (key, op_type_name, attr_value.i,
                                   attr_def.minimum))
                elif attr_def.type == "list(int)":
                    attr_value.list.i.extend([_MakeInt(x, key) for x in value])
                elif attr_def.type == "float":
                    attr_value.f = _MakeFloat(value, key)
                elif attr_def.type == "list(float)":
                    attr_value.list.f.extend(
                        [_MakeFloat(x, key) for x in value])
                elif attr_def.type == "bool":
                    attr_value.b = _MakeBool(value, key)
                elif attr_def.type == "list(bool)":
                    attr_value.list.b.extend(
                        [_MakeBool(x, key) for x in value])
                elif attr_def.type == "type":
                    attr_value.type = _MakeType(value, attr_def)
                elif attr_def.type == "list(type)":
                    attr_value.list.type.extend(
                        [_MakeType(x, attr_def) for x in value])
                elif attr_def.type == "shape":
                    attr_value.shape.CopyFrom(_MakeShape(value, key))
                elif attr_def.type == "list(shape)":
                    attr_value.list.shape.extend(
                        [_MakeShape(x, key) for x in value])
                elif attr_def.type == "tensor":
                    attr_value.tensor.CopyFrom(_MakeTensor(value, key))
                elif attr_def.type == "list(tensor)":
                    attr_value.list.tensor.extend(
                        [_MakeTensor(x, key) for x in value])
                elif attr_def.type == "func":
                    attr_value.func.CopyFrom(_MakeFunc(value, key))
                elif attr_def.type == "list(func)":
                    attr_value.list.func.extend(
                        [_MakeFunc(x, key) for x in value])
                else:
                    raise TypeError("Unrecognized Attr type " + attr_def.type)

                attr_protos[key] = attr_value
            del attrs  # attrs is no longer authoritative, use attr_protos instead

            # Determine output types (possibly using attrs)
            output_types = []
            output_structure = []
            for arg in op_def.output_arg:
                types = []
                if arg.number_attr:
                    n = _AttrValue(attr_protos, arg.number_attr).i
                    if arg.type_attr:
                        types = [_AttrValue(attr_protos, arg.type_attr).type
                                 ] * n
                    else:
                        types = [arg.type] * n
                    output_structure.append(n)
                elif arg.type_attr:
                    t = _AttrValue(attr_protos, arg.type_attr)
                    types = [t.type]
                    output_structure.append(None)
                elif arg.type_list_attr:
                    t = _AttrValue(attr_protos, arg.type_list_attr)
                    types = t.list.type
                    output_structure.append(len(types))
                else:
                    types = [arg.type]
                    output_structure.append(None)
                if arg.is_ref:
                    types = [dtypes.as_dtype(x)._as_ref for x in types]  # pylint: disable=protected-access
                output_types.extend(types)

            # Every input and attr keyword has been popped by now; anything
            # left over is not part of this Op's signature.
            if keywords:
                raise TypeError(
                    "apply_op() got unexpected keyword arguments: " +
                    ", ".join(sorted(keywords.keys())))

            # NOTE(mrry): We add an explicit colocation constraint between
            # the newly created op and any of its reference-typed inputs.
            # `must_colocate_inputs` is filled per input arg in the loop
            # above; zipping `op_def.input_arg` with the flattened `inputs`
            # list would misalign once a list-typed input precedes a ref input.
            with _MaybeColocateWith(must_colocate_inputs):
                # Add Op to graph
                op = g.create_op(op_type_name,
                                 inputs,
                                 output_types,
                                 name=scope,
                                 input_types=input_types,
                                 attrs=attr_protos,
                                 op_def=op_def)
            return output_structure, op_def.is_stateful, op
Exemplo n.º 2
0
  def apply_op(self, op_type_name, name=None, **keywords):
    # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      name: string. Optional name of the created op.
      **keywords: input Tensor and attr arguments specified by name,
        and optional parameters to pass when constructing the Operation.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if there are no outputs.

    Raises:
      RuntimeError: On some errors.
      TypeError: On some errors.
      ValueError: On some errors.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
      raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
      # Need to flatten all the arguments into a list.
      # pylint: disable=protected-access
      g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
      # pyline: enable=protected-access
    except AssertionError as e:
      raise RuntimeError(
          "Cannot determine graph for Op '%s' due to: %s"
          % (op_type_name, e.message))

    # Default name if not specified.
    if name is None:
      name = op_type_name

    # Check for deprecation
    deprecation_version = op_def.deprecation.version
    if deprecation_version:
      producer = g.graph_def_versions.producer
      if producer >= deprecation_version:
        raise NotImplementedError(
            ("Op %s is not available in GraphDef version %d. "
             "It has been removed in version %d. %s.") %
            (op_type_name, producer, deprecation_version,
             op_def.deprecation.explanation))

    # Fill in the list of default types for all "type" attrs.  This
    # will be used to choose a preferred dtype to convert to in the
    # absence of input type information.
    #
    # TODO(b/31302892): Currently the defaults don't work in the right
    # way if you have two inputs, one of whose type resolution depends
    # on the other.  Handling this will require restructuring this code
    # significantly.
    default_type_attr_map = {}
    for attr_def in op_def.attr:
      if attr_def.type != "type":
        continue
      key = attr_def.name
      if attr_def.HasField("default_value"):
        default_type_attr_map[key] = dtypes.as_dtype(
            attr_def.default_value.type)

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

      # Perform input type inference
      inferred_from = {}
      for input_arg in op_def.input_arg:
        input_name = input_arg.name
        if input_name in keywords:
          values = keywords.pop(input_name)
        elif input_name + "_" in keywords:
          # Handle the case where the name is a keyword or built-in
          # for Python so we use the name + _ instead.
          input_name += "_"
          values = keywords.pop(input_name)
        else:
          raise TypeError("No argument for input " + input_name)

        # Goals:
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

            # dtype still not found, prefer using the default dtype
            # from the attr.
            if dtype is None and input_arg.type_attr in default_type_attr_map:
              default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            if not input_arg.is_ref and dtype:
              dtype = dtypes.as_dtype(dtype).base_dtype
            values = ops.internal_convert_n_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype if dtype else None,
                preferred_dtype=default_dtype,
                as_ref=input_arg.is_ref)
            if input_arg.number_attr and len(
                set(v.dtype.base_dtype for v in values)) > 1:
              raise TypeError()  # All types should match.
          except (TypeError, ValueError):
            # What types does the conversion function think values have?
            observed_types = []
            for value in values:
              try:
                converted_value = ops.internal_convert_to_tensor(
                    value, as_ref=input_arg.is_ref)
                observed_types.append(converted_value.dtype.base_dtype.name)
              except (TypeError, ValueError):
                observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
            observed = ", ".join(observed_types)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            if input_arg.number_attr:
              if input_arg.type != types_pb2.DT_INVALID:
                raise TypeError("%s that do not match expected type %s." %
                                (prefix, dtype.name))
              elif input_arg.type_attr in attrs:
                raise TypeError("%s that do not match type %s inferred from "
                                "earlier arguments." %
                                (prefix, dtype.name))
              else:
                raise TypeError("%s that don't all match." % prefix)
            else:
              raise TypeError("%s that are invalid." % prefix)

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          default_dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          elif input_arg.type_attr in default_type_attr_map:
            # The dtype could not be inferred solely from the inputs,
            # so we prefer the attr's default, so code that adds a new attr
            # with a default is backwards compatible.
            default_dtype = default_type_attr_map[input_arg.type_attr]

          try:
            values = ops.internal_convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
          except TypeError as err:
            if dtype is None:
              raise err
            else:
              raise TypeError(
                  "Expected %s passed to parameter '%s' of op '%s', got %s of "
                  "type '%s' instead." %
                  (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name,
                   repr(values), type(values).__name__))
          except ValueError:
            # What type does convert_to_tensor think it has?
            observed = ops.internal_convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
            prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                      (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s expected type of %s." %
                              (prefix, dtypes.as_dtype(input_arg.type).name))
            else:
              # Update the maps with the default, if needed.
              k = input_arg.type_attr
              if k in default_type_attr_map:
                if k not in attrs:
                  attrs[k] = default_type_attr_map[k]
                  if k not in inferred_from:
                    inferred_from[k] = "Default in OpDef"

              raise TypeError(
                  "%s type %s of argument '%s'." %
                  (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                   inferred_from[input_arg.type_attr]))

          types = [values.dtype]
          inputs.append(values)
        base_types = [x.base_dtype for x in types]

        if input_arg.number_attr:
          # <number-attr> * <type> or <number-attr> * <type-attr>
          if input_arg.number_attr in attrs:
            if len(values) != attrs[input_arg.number_attr]:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d must match "
                  "length %d of argument '%s'." %
                  (input_name, op_type_name, len(values),
                   attrs[input_arg.number_attr],
                   inferred_from[input_arg.number_attr]))
          else:
            attrs[input_arg.number_attr] = len(values)
            inferred_from[input_arg.number_attr] = input_name
            num_attr = _Attr(op_def, input_arg.number_attr)
            if num_attr.has_minimum and len(values) < num_attr.minimum:
              raise ValueError(
                  "List argument '%s' to '%s' Op with length %d shorter "
                  "than minimum length %d." %
                  (input_name, op_type_name, len(values), num_attr.minimum))
          # All tensors must have the same base type.
          if any([bt != base_types[0] for bt in base_types]):
            raise TypeError(
                "All tensors passed to '%s' of '%s' Op "
                "must have the same type." %
                (input_name, op_type_name))
          if input_arg.type != types_pb2.DT_INVALID:
            # <number-attr> * <type> case
            if base_types and base_types[0] != input_arg.type:
              assert False, "Unreachable"
          elif input_arg.type_attr in attrs:
            # <number-attr> * <type-attr> case, where <type-attr> already
            # has an inferred value.
            if base_types and base_types[0] != attrs[input_arg.type_attr]:
              assert False, "Unreachable"
          else:
            # <number-attr> * <type-attr> case, where we are now setting
            # the <type-attr> based on this input
            if not base_types:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  "list passed to input '%s' of '%s' Op." %
                  (input_name, op_type_name))
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr)
        elif input_arg.type_attr:
          # <type-attr>
          attr_value = base_types[0]
          if input_arg.type_attr in attrs:
            if attrs[input_arg.type_attr] != attr_value:
              assert False, "Unreachable"
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_attr))
            attrs[input_arg.type_attr] = attr_value
            inferred_from[input_arg.type_attr] = input_name
        elif input_arg.type_list_attr:
          # <type-list-attr>
          attr_value = base_types
          if input_arg.type_list_attr in attrs:
            if attrs[input_arg.type_list_attr] != attr_value:
              raise TypeError(
                  "Input '%s' of '%s' Op has type list of %s that does not "
                  "match type list %s of argument '%s'." %
                  (input_name, op_type_name,
                   ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                   ", ".join(dtypes.as_dtype(x).name
                             for x in attrs[input_arg.type_list_attr]),
                   inferred_from[input_arg.type_list_attr]))
          else:
            for base_type in base_types:
              _SatisfiesTypeConstraint(base_type,
                                       _Attr(op_def, input_arg.type_list_attr))
            attrs[input_arg.type_list_attr] = attr_value
            inferred_from[input_arg.type_list_attr] = input_name
        else:
          # single Tensor with specified type
          if base_types[0] != input_arg.type:
            assert False, "Unreachable"

        if input_arg.is_ref:
          if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
            raise TypeError(
                "Input '%s' of '%s' Op requires l-value input" %
                (input_name, op_type_name))
          input_types.extend(types)
        else:
          input_types.extend(base_types)

      # Process remaining attrs
      for attr in op_def.attr:
        # Skip attrs that have already had their values inferred
        if attr.name in attrs:
          if attr.name in keywords:
            raise TypeError(
                "Should not specify value for inferred attr '%s'." % attr.name)
          continue
        if attr.name in keywords:
          attrs[attr.name] = keywords.pop(attr.name)
        elif attr.name + "_" in keywords:
          # Attrs whose names match Python keywords have an extra '_'
          # appended, so we must check for that as well.
          attrs[attr.name] = keywords.pop(attr.name + "_")
        else:
          raise TypeError("No argument for attr " + attr.name)

      # Convert attr values to AttrValue protos.
      attr_protos = {}
      for attr_def in op_def.attr:
        key = attr_def.name
        value = attrs[key]
        attr_value = attr_value_pb2.AttrValue()
        if attr_def.HasField("default_value") and value is None:
          attr_value.CopyFrom(attr_def.default_value)
          attr_protos[key] = attr_value
          continue
        if attr_def.type.startswith("list("):
          if not _IsListValue(value):
            raise TypeError("Expected list for attr " + key)
          if attr_def.has_minimum:
            if len(value) < attr_def.minimum:
              raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                               "less than minimum %d." %
                               (key, op_type_name, len(value),
                                attr_def.minimum))
          attr_value.list.SetInParent()
        if attr_def.type == "string":
          attr_value.s = _MakeStr(value, key)
          if attr_def.HasField("allowed_values"):
            if attr_value.s not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(attr_value.s),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
        elif attr_def.type == "list(string)":
          attr_value.list.s.extend([_MakeStr(x, key) for x in value])
          if attr_def.HasField("allowed_values"):
            for x in attr_value.list.s:
              if x not in attr_def.allowed_values.list.s:
                raise ValueError(
                    "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                    (key, op_type_name, compat.as_text(x),
                     '", "'.join(map(compat.as_text,
                                     attr_def.allowed_values.list.s))))
        elif attr_def.type == "int":
          attr_value.i = _MakeInt(value, key)
          if attr_def.has_minimum:
            if attr_value.i < attr_def.minimum:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                  (key, op_type_name, attr_value.i, attr_def.minimum))
        elif attr_def.type == "list(int)":
          attr_value.list.i.extend([_MakeInt(x, key) for x in value])
        elif attr_def.type == "float":
          attr_value.f = _MakeFloat(value, key)
        elif attr_def.type == "list(float)":
          attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
        elif attr_def.type == "bool":
          attr_value.b = _MakeBool(value, key)
        elif attr_def.type == "list(bool)":
          attr_value.list.b.extend([_MakeBool(x, key) for x in value])
        elif attr_def.type == "type":
          attr_value.type = _MakeType(value, attr_def)
        elif attr_def.type == "list(type)":
          attr_value.list.type.extend(
              [_MakeType(x, attr_def) for x in value])
        elif attr_def.type == "shape":
          attr_value.shape.CopyFrom(_MakeShape(value, key))
        elif attr_def.type == "list(shape)":
          attr_value.list.shape.extend(
              [_MakeShape(x, key) for x in value])
        elif attr_def.type == "tensor":
          attr_value.tensor.CopyFrom(_MakeTensor(value, key))
        elif attr_def.type == "list(tensor)":
          attr_value.list.tensor.extend(
              [_MakeTensor(x, key) for x in value])
        elif attr_def.type == "func":
          if isinstance(value, attr_value_pb2.NameAttrList):
            attr_value.func.CopyFrom(value)
          elif isinstance(value, compat.bytes_or_text_types):
            attr_value.func.name = value
          else:
            value.add_to_graph(ops.get_default_graph())
            attr_value.func.name = value.name
        else:
          raise TypeError("Unrecognized Attr type " + attr_def.type)

        attr_protos[key] = attr_value
      del attrs  # attrs is no longer authoritative, use attr_protos instead

      # Determine output types (possibly using attrs)
      output_types = []
      output_structure = []
      for arg in op_def.output_arg:
        types = []
        if arg.number_attr:
          n = _AttrValue(attr_protos, arg.number_attr).i
          if arg.type_attr:
            types = [_AttrValue(attr_protos, arg.type_attr).type] * n
          else:
            types = [arg.type] * n
          output_structure.append(n)
        elif arg.type_attr:
          t = _AttrValue(attr_protos, arg.type_attr)
          types = [t.type]
          output_structure.append(None)
        elif arg.type_list_attr:
          t = _AttrValue(attr_protos, arg.type_list_attr)
          types = t.list.type
          output_structure.append(len(types))
        else:
          types = [arg.type]
          output_structure.append(None)
        if arg.is_ref:
          types = [dtypes.as_dtype(x)._as_ref for x in types]  # pylint: disable=protected-access
        output_types.extend(types)

      if keywords:
        raise TypeError("apply_op() got unexpected keyword arguments: " +
                        ", ".join(sorted(keywords.keys())))

      # NOTE(mrry): We add an explicit colocation constraint between
      # the newly created op and any of its reference-typed inputs.
      must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                              if arg.is_ref]
      with _MaybeColocateWith(must_colocate_inputs):
        # Add Op to graph
        op = g.create_op(op_type_name, inputs, output_types, name=scope,
                         input_types=input_types, attrs=attr_protos,
                         op_def=op_def)
        if output_structure:
          outputs = op.outputs
          res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
          if isinstance(res, list) and not res and op_def.is_stateful:
            return op
          else:
            return res
        else:
          return op
Exemplo n.º 3
0
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
    """Implementation of apply_op that returns output_structure, op."""
    op_def = op_def_registry.get(op_type_name)
    if op_def is None:
        raise RuntimeError("Unrecognized Op name " + op_type_name)

    # Determine the graph context.
    try:
        # Need to flatten all the arguments into a list.
        # pylint: disable=protected-access
        g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
        # pylint: enable=protected-access
    except AssertionError as e:
        raise RuntimeError("Cannot determine graph for Op '%s' due to: %s" %
                           (op_type_name, e.message))

    # Default name if not specified.
    if name is None:
        name = op_type_name

    # Check for deprecation
    deprecation_version = op_def.deprecation.version
    if deprecation_version:
        producer = g.graph_def_versions.producer
        if producer >= deprecation_version:
            raise NotImplementedError(
                ("Op %s is not available in GraphDef version %d. "
                 "It has been removed in version %d. %s.") %
                (op_type_name, producer, deprecation_version,
                 op_def.deprecation.explanation))

    # Fill in the list of default types for all "type" attrs.  This
    # will be used to choose a preferred dtype to convert to in the
    # absence of input type information.
    #
    # TODO(b/31302892): Currently the defaults don't work in the right
    # way if you have two inputs, one of whose type resolution depends
    # on the other.  Handling this will require restructuring this code
    # significantly.
    default_type_attr_map = {}
    allowed_list_attr_map = {}
    for attr_def in op_def.attr:
        if attr_def.type != "type":
            continue
        key = attr_def.name
        if attr_def.HasField("default_value"):
            default_type_attr_map[key] = dtypes.as_dtype(
                attr_def.default_value.type)
        if attr_def.HasField("allowed_values"):
            allowed_list_attr_map[key] = attr_def.allowed_values.list.type

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

        # Perform input type inference
        inferred_from = {}
        for input_arg in op_def.input_arg:
            input_name = input_arg.name
            if input_name in keywords:
                values = keywords.pop(input_name)
            elif input_name + "_" in keywords:
                # Handle the case where the name is a keyword or built-in
                # for Python so we use the name + _ instead.
                input_name += "_"
                values = keywords.pop(input_name)
            else:
                raise TypeError("No argument for input " + input_name)

            # Goals:
            # * Convert values to Tensors if it contains constants.
            # * Verify that values is a list if that matches the input_arg's
            #   type.
            # * If the input_arg's type is determined by attrs, either set
            #   those attrs and validate those attr values are legal (if
            #   they have not yet been set) or validate the input matches
            #   the type indicated by the attrs (if they have already been
            #   inferred via an earlier input).
            # * If the input_arg has an explicit type, make sure the input
            #   conforms.

            if _IsListParameter(input_arg):
                if not _IsListValue(values):
                    raise TypeError(
                        "Expected list for '%s' argument to '%s' Op, not %s." %
                        (input_name, op_type_name, values))
                # In cases where we expect all elements of the list to have the
                # same dtype, try to cast non-Tensor elements to that type.
                dtype = None
                default_dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.number_attr:
                    if input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                    else:
                        for t in values:
                            if isinstance(t, ops.Tensor):
                                dtype = t.dtype
                                break

                    # dtype still not found, prefer using the default dtype
                    # from the attr.
                    if dtype is None and input_arg.type_attr in default_type_attr_map:
                        default_dtype = default_type_attr_map[
                            input_arg.type_attr]

                try:
                    if not input_arg.is_ref and dtype:
                        dtype = dtypes.as_dtype(dtype).base_dtype
                    values = ops.internal_convert_n_to_tensor(
                        values,
                        name=input_arg.name,
                        dtype=dtype if dtype else None,
                        preferred_dtype=default_dtype,
                        as_ref=input_arg.is_ref)
                    if input_arg.number_attr and len(
                            set(v.dtype.base_dtype for v in values)) > 1:
                        raise TypeError()  # All types should match.
                except (TypeError, ValueError):
                    # What types does the conversion function think values have?
                    observed_types = []
                    for value in values:
                        try:
                            converted_value = ops.convert_to_tensor(
                                value, as_ref=input_arg.is_ref)
                            observed_types.append(
                                converted_value.dtype.base_dtype.name)
                        except (TypeError, ValueError):
                            observed_types.append(
                                "<NOT CONVERTIBLE TO TENSOR>")
                    observed = ", ".join(observed_types)

                    prefix = (
                        "Tensors in list passed to '%s' of '%s' Op have types [%s]"
                        % (input_name, op_type_name, observed))
                    if input_arg.number_attr:
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s that do not match expected type %s." %
                                (prefix, dtype.name))
                        elif input_arg.type_attr in attrs:
                            raise TypeError(
                                "%s that do not match type %s inferred from "
                                "earlier arguments." % (prefix, dtype.name))
                        else:
                            raise TypeError("%s that don't all match." %
                                            prefix)
                    else:
                        raise TypeError("%s that are invalid. Tensors: %s" %
                                        (prefix, values))

                types = [x.dtype for x in values]
                inputs.extend(values)
            else:
                # In cases where we have an expected type, try to convert non-Tensor
                # arguments to that type.
                dtype = None
                default_dtype = None
                allowed_list = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.type_attr in attrs:
                    dtype = attrs[input_arg.type_attr]
                elif input_arg.type_attr in default_type_attr_map:
                    # The dtype could not be inferred solely from the inputs,
                    # so we prefer the attr's default, so code that adds a new attr
                    # with a default is backwards compatible.
                    default_dtype = default_type_attr_map[input_arg.type_attr]
                    allowed_list = allowed_list_attr_map.get(
                        input_arg.type_attr)

                try:
                    # First see if we can get a valid dtype with the default conversion
                    # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
                    # not list allowed dtypes, in which case we should skip this.
                    if dtype is None and allowed_list:
                        inferred = None
                        try:
                            inferred = ops.convert_to_tensor(
                                values,
                                name=input_arg.name,
                                as_ref=input_arg.is_ref)
                        except TypeError as err:
                            # When converting a python object such as a list of Dimensions, we
                            # need a dtype to be specified, thus tensor conversion may throw
                            # an exception which we will ignore and try again below.
                            pass

                        # If we did not match an allowed dtype, try again with the default
                        # dtype. This could be because we have an empty tensor and thus we
                        # picked the wrong type.
                        if inferred is not None and inferred.dtype in allowed_list:
                            values = inferred
                        else:
                            values = ops.convert_to_tensor(
                                values,
                                name=input_arg.name,
                                as_ref=input_arg.is_ref,
                                preferred_dtype=default_dtype)
                    else:
                        values = ops.convert_to_tensor(
                            values,
                            name=input_arg.name,
                            dtype=dtype,
                            as_ref=input_arg.is_ref,
                            preferred_dtype=default_dtype)
                except TypeError as err:
                    if dtype is None:
                        raise err
                    else:
                        raise TypeError(
                            "Expected %s passed to parameter '%s' of op '%s', got %s of "
                            "type '%s' instead. Error: %s" %
                            (dtypes.as_dtype(dtype).name,
                             input_arg.name, op_type_name, repr(values),
                             type(values).__name__, err))
                except ValueError:
                    # What type does convert_to_tensor think it has?
                    try:
                        observed = ops.convert_to_tensor(
                            values, as_ref=input_arg.is_ref).dtype.name
                    except ValueError as err:
                        raise ValueError(
                            "Tried to convert '%s' to a tensor and failed. Error: %s"
                            % (input_name, err))
                    prefix = (
                        "Input '%s' of '%s' Op has type %s that does not match"
                        % (input_name, op_type_name, observed))
                    if input_arg.type != types_pb2.DT_INVALID:
                        raise TypeError(
                            "%s expected type of %s." %
                            (prefix, dtypes.as_dtype(input_arg.type).name))
                    else:
                        # Update the maps with the default, if needed.
                        k = input_arg.type_attr
                        if k in default_type_attr_map:
                            if k not in attrs:
                                attrs[k] = default_type_attr_map[k]
                                if k not in inferred_from:
                                    inferred_from[k] = "Default in OpDef"

                        raise TypeError(
                            "%s type %s of argument '%s'." %
                            (prefix, dtypes.as_dtype(
                                attrs[input_arg.type_attr]).name,
                             inferred_from[input_arg.type_attr]))

                types = [values.dtype]
                inputs.append(values)
            base_types = [x.base_dtype for x in types]

            if input_arg.number_attr:
                # <number-attr> * <type> or <number-attr> * <type-attr>
                if input_arg.number_attr in attrs:
                    if len(values) != attrs[input_arg.number_attr]:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d must match "
                            "length %d of argument '%s'." %
                            (input_name, op_type_name, len(values),
                             attrs[input_arg.number_attr],
                             inferred_from[input_arg.number_attr]))
                else:
                    attrs[input_arg.number_attr] = len(values)
                    inferred_from[input_arg.number_attr] = input_name
                    num_attr = _Attr(op_def, input_arg.number_attr)
                    if num_attr.has_minimum and len(values) < num_attr.minimum:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d shorter "
                            "than minimum length %d." %
                            (input_name, op_type_name, len(values),
                             num_attr.minimum))
                # All tensors must have the same base type.
                if any(bt != base_types[0] for bt in base_types):
                    raise TypeError("All tensors passed to '%s' of '%s' Op "
                                    "must have the same type." %
                                    (input_name, op_type_name))
                if input_arg.type != types_pb2.DT_INVALID:
                    # <number-attr> * <type> case
                    if base_types and base_types[0] != input_arg.type:
                        assert False, "Unreachable"
                elif input_arg.type_attr in attrs:
                    # <number-attr> * <type-attr> case, where <type-attr> already
                    # has an inferred value.
                    if base_types and base_types[0] != attrs[
                            input_arg.type_attr]:
                        assert False, "Unreachable"
                else:
                    # <number-attr> * <type-attr> case, where we are now setting
                    # the <type-attr> based on this input
                    if not base_types:
                        # If it's in default_type_attr_map, then wait to set it
                        # (in "process remaining attrs", below).
                        if input_arg.type_attr not in default_type_attr_map:
                            raise TypeError(
                                "Don't know how to infer type variable from empty input "
                                "list passed to input '%s' of '%s' Op." %
                                (input_name, op_type_name))
                    else:
                        attrs[input_arg.type_attr] = base_types[0]
                        inferred_from[input_arg.type_attr] = input_name
                        type_attr = _Attr(op_def, input_arg.type_attr)
                        _SatisfiesTypeConstraint(base_types[0],
                                                 type_attr,
                                                 param_name=input_name)
            elif input_arg.type_attr:
                # <type-attr>
                attr_value = base_types[0]
                if input_arg.type_attr in attrs:
                    if attrs[input_arg.type_attr] != attr_value:
                        raise TypeError(
                            "Input '%s' of '%s' Op has type %s that does not "
                            "match type %s of argument '%s'." %
                            (input_name, op_type_name,
                             dtypes.as_dtype(attr_value).name,
                             dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                             inferred_from[input_arg.type_attr]))
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(base_type,
                                                 _Attr(op_def,
                                                       input_arg.type_attr),
                                                 param_name=input_name)
                    attrs[input_arg.type_attr] = attr_value
                    inferred_from[input_arg.type_attr] = input_name
            elif input_arg.type_list_attr:
                # <type-list-attr>
                attr_value = base_types
                if input_arg.type_list_attr in attrs:
                    if attrs[input_arg.type_list_attr] != attr_value:
                        raise TypeError(
                            "Input '%s' of '%s' Op has type list of %s that does not "
                            "match type list %s of argument '%s'." %
                            (input_name, op_type_name, ", ".join(
                                dtypes.as_dtype(x).name
                                for x in attr_value), ", ".join(
                                    dtypes.as_dtype(x).name
                                    for x in attrs[input_arg.type_list_attr]),
                             inferred_from[input_arg.type_list_attr]))
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(base_type,
                                                 _Attr(
                                                     op_def,
                                                     input_arg.type_list_attr),
                                                 param_name=input_name)
                    attrs[input_arg.type_list_attr] = attr_value
                    inferred_from[input_arg.type_list_attr] = input_name
            else:
                # single Tensor with specified type
                if base_types[0] != input_arg.type:
                    assert False, "Unreachable"

            if input_arg.is_ref:
                if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
                    raise TypeError((
                        "'%s' Op requires that input '%s' be a mutable tensor "
                        "(e.g.: a tf.Variable)") % (op_type_name, input_name))
                input_types.extend(types)
            else:
                input_types.extend(base_types)

        # Process remaining attrs
        for attr in op_def.attr:
            # Skip attrs that have already had their values inferred
            if attr.name in attrs:
                if attr.name in keywords:
                    raise TypeError(
                        "Should not specify value for inferred attr '%s'." %
                        attr.name)
                continue
            if attr.name in keywords:
                attrs[attr.name] = keywords.pop(attr.name)
            elif attr.name + "_" in keywords:
                # Attrs whose names match Python keywords have an extra '_'
                # appended, so we must check for that as well.
                attrs[attr.name] = keywords.pop(attr.name + "_")
            elif attr.name in default_type_attr_map:
                attrs[attr.name] = default_type_attr_map[attr.name]
                inferred_from.setdefault(attr.name, "Default in OpDef")
            else:
                raise TypeError("No argument for attr " + attr.name)

        # Convert attr values to AttrValue protos.
        attr_protos = {}
        for attr_def in op_def.attr:
            key = attr_def.name
            value = attrs[key]

            if attr_def.HasField("default_value") and value is None:
                attr_value = attr_value_pb2.AttrValue()
                attr_value.CopyFrom(attr_def.default_value)
                attr_protos[key] = attr_value
                continue

            attr_value = value_to_attr_value(value, attr_def.type, key)
            if attr_def.type.startswith("list("):
                _SatisfiesLengthConstraint(len(value), attr_def, key,
                                           op_type_name)
            if attr_def.HasField("allowed_values"):
                if attr_def.type == "string":
                    _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def,
                                                       key, op_type_name)
                elif attr_def.type == "list(string)":
                    for value in attr_value.list.s:
                        _SatisfiesAllowedStringsConstraint(
                            value, attr_def, key, op_type_name)
            if attr_def.has_minimum and attr_def.type == "int":
                _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
                                               op_type_name)
            if attr_def.type == "type":
                _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
            if attr_def.type == "list(type)":
                for value in attr_value.list.type:
                    _SatisfiesTypeConstraint(value, attr_def, key)

            attr_protos[key] = attr_value
        del attrs  # attrs is no longer authoritative, use attr_protos instead

        # Determine output types (possibly using attrs)
        output_structure = []
        for arg in op_def.output_arg:
            if arg.number_attr:
                n = _AttrValue(attr_protos, arg.number_attr).i
                output_structure.append(n)
            elif arg.type_attr:
                t = _AttrValue(attr_protos, arg.type_attr)
                output_structure.append(None)
            elif arg.type_list_attr:
                t = _AttrValue(attr_protos, arg.type_list_attr)
                output_structure.append(len(t.list.type))
            else:
                output_structure.append(None)

        if keywords:
            raise TypeError("apply_op() got unexpected keyword arguments: " +
                            ", ".join(sorted(keywords.keys())))

        # NOTE(mrry): We add an explicit colocation constraint between
        # the newly created op and any of its reference-typed inputs.
        must_colocate_inputs = [
            val for arg, val in zip(op_def.input_arg, inputs) if arg.is_ref
        ]
        with _MaybeColocateWith(must_colocate_inputs):
            # Add Op to graph
            # pylint: disable=protected-access
            op = g._create_op_internal(op_type_name,
                                       inputs,
                                       dtypes=None,
                                       name=scope,
                                       input_types=input_types,
                                       attrs=attr_protos,
                                       op_def=op_def)

        # `outputs` is returned as a separate return value so that the output
        # tensors can the `op` per se can be decoupled so that the
        # `op_callbacks` can function properly. See framework/op_callbacks.py
        # for more details.
        outputs = op.outputs
        # Conditionally invoke tfdbg v2's op callback(s).
        if op_callbacks.should_invoke_op_callbacks():
            callback_outputs = op_callbacks.invoke_op_callbacks(
                op.node_def.op,
                tuple(op.inputs),
                attr_protos,
                tuple(outputs),
                op_name=op.name,
                graph=g)
            if callback_outputs is not None:
                outputs = callback_outputs

        return output_structure, op_def.is_stateful, op, outputs
Exemplo n.º 4
0
def _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                           keywords, default_type_attr_map, attrs, inputs,
                           input_types):
  """Extracts `attrs`, `inputs`, and `input_types` in _apply_op_helper.

  Walks `op_def.input_arg` in order. For each input argument it pops the
  matching value from `keywords` (trying `name` first, then `name + "_"`),
  converts it to one or more tensors, checks it against any type/length
  information already known, and records the results in the caller-provided
  containers. Nothing is returned; all results are written into the mutable
  arguments.

  Args:
    op_type_name: Name of the op; used only in error messages.
    op_def: The `OpDef` whose `input_arg` entries are processed.
    allowed_list_attr_map: Maps a type-attr name to its allowed dtypes. It is
      only consulted for a non-list input whose type attr has a default and
      has not been inferred yet.
    keywords: Dict of caller-supplied arguments. Every matched input is popped
      from it, so afterwards it holds only the non-input entries.
    default_type_attr_map: Maps a type-attr name to its default dtype, used as
      the preferred dtype when no type can be inferred from the inputs.
    attrs: Dict updated in place with number attrs (list lengths), type attrs
      and type-list attrs inferred from the inputs.
    inputs: List extended in place with the converted tensors.
    input_types: List extended in place with one dtype per tensor: the tensor
      dtype as-is for ref inputs, its base dtype otherwise.

  Raises:
    TypeError: If an input is missing, a list input receives a non-list
      value, tensor types conflict with each other or with previously
      inferred attrs, a type cannot be inferred from an empty list, or a ref
      input is not a mutable tensor.
    ValueError: If a non-list value cannot be converted to a tensor at all, or
      a list input's length conflicts with its number attr or is below the
      attr's minimum.
  """
  # attr name -> name of the input (or a note) its value was inferred from;
  # only used to build error messages.
  inferred_from = {}
  for input_arg in op_def.input_arg:
    input_name = input_arg.name
    if input_name in keywords:
      values = keywords.pop(input_name)
    elif input_name + "_" in keywords:
      # Handle the case where the name is a keyword or built-in
      # for Python so we use the name + _ instead.
      input_name += "_"
      values = keywords.pop(input_name)
    else:
      raise TypeError(f"No argument for input {input_name} found in {op_def}")

    # Goals:
    # * Convert values to Tensors if it contains constants.
    # * Verify that values is a list if that matches the input_arg's
    #   type.
    # * If the input_arg's type is determined by attrs, either set
    #   those attrs and validate those attr values are legal (if
    #   they have not yet been set) or validate the input matches
    #   the type indicated by the attrs (if they have already been
    #   inferred via an earlier input).
    # * If the input_arg has an explicit type, make sure the input
    #   conforms.

    if _IsListParameter(input_arg):
      if not _IsListValue(values):
        raise TypeError(
            f"Expected list for '{input_name}' argument to '{op_type_name}' "
            f"Op, not {values}.")
      # In cases where we expect all elements of the list to have the
      # same dtype, try to cast non-Tensor elements to that type.
      dtype = None
      default_dtype = None
      if input_arg.type != types_pb2.DT_INVALID:
        dtype = input_arg.type
      elif input_arg.number_attr:
        if input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        else:
          # Use the dtype of the first element that is already a Tensor.
          for t in values:
            if isinstance(t, ops.Tensor):
              dtype = t.dtype
              break

        # dtype still not found, prefer using the default dtype
        # from the attr.
        if dtype is None and input_arg.type_attr in default_type_attr_map:
          default_dtype = default_type_attr_map[input_arg.type_attr]

      try:
        if not input_arg.is_ref and dtype:
          dtype = dtypes.as_dtype(dtype).base_dtype
        values = ops.internal_convert_n_to_tensor(
            values,
            name=input_arg.name,
            dtype=dtype if dtype else None,
            preferred_dtype=default_dtype,
            as_ref=input_arg.is_ref)
        all_types = {v.dtype.base_dtype for v in values}
        if input_arg.number_attr and len(all_types) > 1:
          # All types should match.
          raise TypeError(f"Not all types matched for {input_arg.name} for "
                          f"{op_type_name}. Got {all_types}")
      except (TypeError, ValueError):
        # What types does the conversion function think values have?
        observed_types = []
        for value in values:
          try:
            converted_value = ops.convert_to_tensor(
                value, as_ref=input_arg.is_ref)
            observed_types.append(converted_value.dtype.base_dtype.name)
          except (TypeError, ValueError):
            observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
        observed = ", ".join(observed_types)

        prefix = ("Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                  (input_name, op_type_name, observed))
        if input_arg.number_attr:
          # `dtype` is not guaranteed to be a DType here: ref inputs skip
          # the `as_dtype` normalization above, leaving the raw OpDef enum.
          # Normalize before reading `.name` so the intended TypeError is
          # raised instead of an AttributeError.
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError(f"{prefix} that do not match expected type "
                            f"{dtypes.as_dtype(dtype).name}.")
          elif input_arg.type_attr in attrs:
            raise TypeError(f"{prefix} that do not match type "
                            f"{dtypes.as_dtype(dtype).name} "
                            "inferred from earlier arguments.")
          else:
            raise TypeError(f"{prefix} that don't all match.")
        else:
          raise TypeError(f"{prefix} that are invalid. Tensors: {values}")

      types = [x.dtype for x in values]
      inputs.extend(values)
    else:
      # In cases where we have an expected type, try to convert non-Tensor
      # arguments to that type.
      dtype = None
      default_dtype = None
      allowed_list = None
      if input_arg.type != types_pb2.DT_INVALID:
        dtype = input_arg.type
      elif input_arg.type_attr in attrs:
        dtype = attrs[input_arg.type_attr]
      elif input_arg.type_attr in default_type_attr_map:
        # The dtype could not be inferred solely from the inputs,
        # so we prefer the attr's default, so code that adds a new attr
        # with a default is backwards compatible.
        default_dtype = default_type_attr_map[input_arg.type_attr]
        allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

      try:
        # First see if we can get a valid dtype with the default conversion
        # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
        # not list allowed dtypes, in which case we should skip this.
        if dtype is None and allowed_list:
          inferred = None
          try:
            inferred = ops.convert_to_tensor(
                values, name=input_arg.name, as_ref=input_arg.is_ref)
          except TypeError:
            # When converting a python object such as a list of Dimensions, we
            # need a dtype to be specified, thus tensor conversion may throw
            # an exception which we will ignore and try again below.
            pass

          # If we did not match an allowed dtype, try again with the default
          # dtype. This could be because we have an empty tensor and thus we
          # picked the wrong type.
          if inferred is not None and inferred.dtype in allowed_list:
            values = inferred
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        else:
          values = ops.convert_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype,
              as_ref=input_arg.is_ref,
              preferred_dtype=default_dtype)
      except TypeError as err:
        if dtype is None:
          # No expected dtype to report against; surface the original error.
          raise
        raise TypeError(
            f"Expected {dtypes.as_dtype(dtype).name} passed to parameter "
            f"'{input_arg.name}' of op '{op_type_name}', got "
            f"{repr(values)} of type '{type(values).__name__}' instead. "
            f"Error: {err}") from err
      except ValueError:
        # What type does convert_to_tensor think it has?
        try:
          observed = ops.convert_to_tensor(
              values, as_ref=input_arg.is_ref).dtype.name
        except ValueError as err:
          raise ValueError(
              f"Tried to convert '{input_name}' to a tensor and failed. "
              f"Error: {err}") from err
        prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                  (input_name, op_type_name, observed))
        if input_arg.type != types_pb2.DT_INVALID:
          raise TypeError(f"{prefix} expected type of "
                          f"{dtypes.as_dtype(input_arg.type).name}.")
        else:
          # Update the maps with the default, if needed.
          k = input_arg.type_attr
          if k in default_type_attr_map:
            if k not in attrs:
              attrs[k] = default_type_attr_map[k]
              if k not in inferred_from:
                inferred_from[k] = "Default in OpDef"

          raise TypeError(
              f"{prefix} type "
              f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
              f"argument '{inferred_from[input_arg.type_attr]}'.")

      types = [values.dtype]
      inputs.append(values)
    base_types = [x.base_dtype for x in types]

    if input_arg.number_attr:
      # <number-attr> * <type> or <number-attr> * <type-attr>
      if input_arg.number_attr in attrs:
        # The list length was already fixed by an earlier input.
        if len(values) != attrs[input_arg.number_attr]:
          raise ValueError(
              f"List argument '{input_name}' to '{op_type_name}' Op with "
              f"length {len(values)} must match length "
              f"{attrs[input_arg.number_attr]} of argument "
              f"'{inferred_from[input_arg.number_attr]}'.")
      else:
        attrs[input_arg.number_attr] = len(values)
        inferred_from[input_arg.number_attr] = input_name
        num_attr = _Attr(op_def, input_arg.number_attr)
        if num_attr.has_minimum and len(values) < num_attr.minimum:
          raise ValueError(
              f"List argument '{input_name}' to '{op_type_name}' Op with "
              f"length {len(values)} shorter than minimum length "
              f"{num_attr.minimum}.")
      # All tensors must have the same base type.
      if any(bt != base_types[0] for bt in base_types):
        raise TypeError(
            f"All tensors passed to '{input_name}' of '{op_type_name}' Op "
            f"must have the same type. Got {base_types} instead.")
      if input_arg.type != types_pb2.DT_INVALID:
        # <number-attr> * <type> case
        # The conversion above was given this dtype, so a mismatch is not
        # expected here.
        if base_types and base_types[0] != input_arg.type:
          assert False, "Unreachable"
      elif input_arg.type_attr in attrs:
        # <number-attr> * <type-attr> case, where <type-attr> already
        # has an inferred value.
        if base_types and base_types[0] != attrs[input_arg.type_attr]:
          assert False, "Unreachable"
      else:
        # <number-attr> * <type-attr> case, where we are now setting
        # the <type-attr> based on this input
        if not base_types:
          # If it's in default_type_attr_map, then wait to set it
          # (in "process remaining attrs", below).
          if input_arg.type_attr not in default_type_attr_map:
            raise TypeError(
                "Don't know how to infer type variable from empty input "
                f"list passed to input '{input_name}' of '{op_type_name}' "
                "Op.")
        else:
          attrs[input_arg.type_attr] = base_types[0]
          inferred_from[input_arg.type_attr] = input_name
          type_attr = _Attr(op_def, input_arg.type_attr)
          _SatisfiesTypeConstraint(
              base_types[0], type_attr, param_name=input_name)
    elif input_arg.type_attr:
      # <type-attr>
      attr_value = base_types[0]
      if input_arg.type_attr in attrs:
        if attrs[input_arg.type_attr] != attr_value:
          raise TypeError(
              f"Input '{input_name}' of '{op_type_name}' Op has type "
              f"{dtypes.as_dtype(attr_value).name} that does not match type "
              f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
              f"argument '{inferred_from[input_arg.type_attr]}'.")
      else:
        for base_type in base_types:
          _SatisfiesTypeConstraint(
              base_type,
              _Attr(op_def, input_arg.type_attr),
              param_name=input_name)
        attrs[input_arg.type_attr] = attr_value
        inferred_from[input_arg.type_attr] = input_name
    elif input_arg.type_list_attr:
      # <type-list-attr>
      attr_value = base_types
      if input_arg.type_list_attr in attrs:
        if attrs[input_arg.type_list_attr] != attr_value:
          actual_types = ", ".join(dtypes.as_dtype(x).name for x in attr_value)
          expected_types = ", ".join(
              dtypes.as_dtype(x).name for x in attrs[input_arg.type_list_attr])
          raise TypeError(
              f"Input '{input_name}' of '{op_type_name}' Op has type list of "
              f"{actual_types} that does not match type list {expected_types}"
              f" of argument '{inferred_from[input_arg.type_list_attr]}'.")
      else:
        for base_type in base_types:
          _SatisfiesTypeConstraint(
              base_type,
              _Attr(op_def, input_arg.type_list_attr),
              param_name=input_name)
        attrs[input_arg.type_list_attr] = attr_value
        inferred_from[input_arg.type_list_attr] = input_name
    else:
      # single Tensor with specified type
      if base_types[0] != input_arg.type:
        assert False, "Unreachable"

    if input_arg.is_ref:
      # Ref inputs keep their full (ref) dtypes in `input_types`.
      if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
        raise TypeError(
            f"'{op_type_name}' Op requires that input '{input_name}' be a "
            "mutable tensor (e.g.: a tf.Variable)")
      input_types.extend(types)
    else:
      input_types.extend(base_types)