Example #1
def _constant_impl(
    value, dtype, shape, name, verify_shape, allow_broadcast):
  """Implementation of constant."""
  ctx = context.context()
  if ctx.executing_eagerly():
    if trace.enabled:
      with trace.Trace("tf.constant"):
        return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
    return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)

  g = ops.get_default_graph()
  tensor_value = attr_value_pb2.AttrValue()
  tensor_value.tensor.CopyFrom(
      tensor_util.make_tensor_proto(
          value, dtype=dtype, shape=shape, verify_shape=verify_shape,
          allow_broadcast=allow_broadcast))
  dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
  attrs = {"value": tensor_value, "dtype": dtype_value}
  const_tensor = g._create_op_internal(  # pylint: disable=protected-access
      "Const", [], [dtype_value.type], attrs=attrs, name=name).outputs[0]

  if op_callbacks.should_invoke_op_callbacks():
    # TODO(b/147670703): Once the special-op creation code paths
    # are unified, remove this `if` block.
    callback_outputs = op_callbacks.invoke_op_callbacks(
        "Const", tuple(), attrs, (const_tensor,), op_name=name, graph=g)
    if callback_outputs is not None:
      const_tensor, = callback_outputs
  return const_tensor
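
A minimal usage sketch (assuming a TensorFlow 2.x install): the public tf.constant() entry point delegates to _constant_impl above, taking the eager branch at the top level and the graph branch inside a tf.Graph scope.

import tensorflow as tf

# Graph branch: inside a graph scope, _constant_impl emits a "Const" node
# via g._create_op_internal and returns its single output tensor.
g = tf.Graph()
with g.as_default():
  c = tf.constant([1, 2, 3], dtype=tf.int32, name="my_const")
  print(c.op.type)               # Const
  print(c.op.get_attr("dtype"))  # <dtype: 'int32'>

# Eager branch: at the top level an EagerTensor is returned directly.
e = tf.constant([1, 2, 3])
print(type(e).__name__)          # EagerTensor
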
Example #2
def graph_placeholder(dtype, shape, name=None):
    """Graph-only version of tf.compat.v1.placeholder(), for internal use only."""
    dtype = dtype.base_dtype
    dtype_value = attr_value_pb2.AttrValue(type=dtype.as_datatype_enum)
    if isinstance(shape, (list, tuple)):
        shape = tensor_shape.TensorShape(shape)
    shape = attr_value_pb2.AttrValue(shape=shape.as_proto())
    g = ops.get_default_graph()
    attrs = {"dtype": dtype_value, "shape": shape}
    op = g._create_op_internal(  # pylint: disable=protected-access
        "Placeholder", [], [dtype],
        input_types=[],
        attrs=attrs,
        name=name)
    result, = op.outputs
    if op_callbacks.should_invoke_op_callbacks():
        # TODO(b/147670703): Once the special-op creation code paths
        # are unified, remove this `if` block.
        callback_outputs = op_callbacks.invoke_op_callbacks("Placeholder",
                                                            tuple(),
                                                            attrs,
                                                            tuple(op.outputs),
                                                            op_name=name,
                                                            graph=g)
        if callback_outputs is not None:
            result, = callback_outputs
    return result
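
graph_placeholder is internal, but its effect is observable through the public tf.compat.v1.placeholder in graph mode. A short sketch (TF 2.x assumed):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # Creates a "Placeholder" op carrying dtype and shape attrs, as above.
    p = tf.compat.v1.placeholder(tf.float32, shape=[None, 4], name="x")
    print(p.op.type)  # Placeholder
    print(p.shape)    # (None, 4)
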
Example #3
def _constant_impl(value, dtype, shape, name, verify_shape, allow_broadcast):
    """Implementation of constant."""
    ctx = context.context()
    if ctx.executing_eagerly():
        t = convert_to_eager_tensor(value, ctx, dtype)
        if shape is None:
            return t
        shape = tensor_shape.as_shape(shape)
        if shape == t.shape:
            return t
        if verify_shape:
            raise TypeError("Expected Tensor's shape: %s, got %s." %
                            (tuple(shape), tuple(t.shape)))
        num_t = t.shape.num_elements()
        # TODO(josh11b): Implement shape -> eager tensor conversion.
        if num_t == shape.num_elements():
            return _eager_reshape(t, shape.as_list(), ctx)
        if num_t == 1:
            if t.dtype == dtypes.bool:
                # We don't have a Fill kernel for bool dtype on GPU. So we first run
                # Fill on CPU and then copy to GPU if needed.
                with ops.device("/device:CPU:0"):
                    x = _eager_fill(shape.as_list(), _eager_identity(t, ctx),
                                    ctx)
                return _eager_identity(x, ctx)
            else:
                return _eager_fill(shape.as_list(), t, ctx)
        raise TypeError(
            "Eager execution of tf.constant with unsupported shape "
            "(value has %d elements, shape is %s with %d elements)." %
            (num_t, shape, shape.num_elements()))
    g = ops.get_default_graph()
    tensor_value = attr_value_pb2.AttrValue()
    tensor_value.tensor.CopyFrom(
        tensor_util.make_tensor_proto(value,
                                      dtype=dtype,
                                      shape=shape,
                                      verify_shape=verify_shape,
                                      allow_broadcast=allow_broadcast))
    dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
    attrs = {"value": tensor_value, "dtype": dtype_value}
    const_tensor = g._create_op_internal(  # pylint: disable=protected-access
        "Const", [], [dtype_value.type],
        attrs=attrs,
        name=name).outputs[0]

    if op_callbacks.should_invoke_op_callbacks():
        # TODO(b/147670703): Once the special-op creation code paths
        # are unified, remove this `if` block.
        callback_outputs = op_callbacks.invoke_op_callbacks("Const",
                                                            tuple(),
                                                            attrs,
                                                            (const_tensor, ),
                                                            op_name=name,
                                                            graph=g)
        if callback_outputs is not None:
            const_tensor, = callback_outputs
    return const_tensor
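
The eager branch above distinguishes the shape cases explicitly: an exact match is returned as-is, a matching element count is reshaped via _eager_reshape, a single element is broadcast via _eager_fill, and anything else raises TypeError. A short sketch of the observable behavior (TF 2.x assumed):

import tensorflow as tf

print(tf.constant(7, shape=[2, 3]))             # one element, filled to 2x3
print(tf.constant([1, 2, 3, 4], shape=[2, 2]))  # 4 elements, reshaped to 2x2
try:
    tf.constant([1, 2, 3], shape=[2, 2])        # 3 elements cannot fill 4 slots
except TypeError as err:
    print(err)
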
Example #4
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
    """Implementation of apply_op that returns output_structure, op."""
    op_def = op_def_registry.get(op_type_name)
    if op_def is None:
        raise RuntimeError("Unrecognized Op name " + op_type_name)

    # Determine the graph context.
    try:
        # Need to flatten all the arguments into a list.
        # pylint: disable=protected-access
        g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
        # pylint: enable=protected-access
    except AssertionError as e:
        raise RuntimeError("Cannot determine graph for Op '%s' due to: %s" %
                           (op_type_name, e.message))

    # Default name if not specified.
    if name is None:
        name = op_type_name

    # Check for deprecation
    deprecation_version = op_def.deprecation.version
    if deprecation_version:
        producer = g.graph_def_versions.producer
        if producer >= deprecation_version:
            raise NotImplementedError(
                ("Op %s is not available in GraphDef version %d. "
                 "It has been removed in version %d. %s.") %
                (op_type_name, producer, deprecation_version,
                 op_def.deprecation.explanation))

    # Fill in the list of default types for all "type" attrs.  This
    # will be used to choose a preferred dtype to convert to in the
    # absence of input type information.
    #
    # TODO(b/31302892): Currently the defaults don't work in the right
    # way if you have two inputs, one of whose type resolution depends
    # on the other.  Handling this will require restructuring this code
    # significantly.
    default_type_attr_map = {}
    for attr_def in op_def.attr:
        if attr_def.type != "type":
            continue
        key = attr_def.name
        if attr_def.HasField("default_value"):
            default_type_attr_map[key] = dtypes.as_dtype(
                attr_def.default_value.type)

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).
    attrs = {}
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

        # Perform input type inference
        inferred_from = {}
        for input_arg in op_def.input_arg:
            input_name = input_arg.name
            if input_name in keywords:
                values = keywords.pop(input_name)
            elif input_name + "_" in keywords:
                # Handle the case where the name is a Python keyword or
                # built-in, so the name + "_" is used instead.
                input_name += "_"
                values = keywords.pop(input_name)
            else:
                raise TypeError("No argument for input " + input_name)

            # Goals:
            # * Convert values to Tensors if they contain constants.
            # * Verify that values is a list if that matches the input_arg's
            #   type.
            # * If the input_arg's type is determined by attrs, either set
            #   those attrs and validate those attr values are legal (if
            #   they have not yet been set) or validate the input matches
            #   the type indicated by the attrs (if they have already been
            #   inferred via an earlier input).
            # * If the input_arg has an explicit type, make sure the input
            #   conforms.

            if _IsListParameter(input_arg):
                if not _IsListValue(values):
                    raise TypeError(
                        "Expected list for '%s' argument to '%s' Op, not %s." %
                        (input_name, op_type_name, values))
                # In cases where we expect all elements of the list to have the
                # same dtype, try to cast non-Tensor elements to that type.
                dtype = None
                default_dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.number_attr:
                    if input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                    else:
                        for t in values:
                            if isinstance(t, ops.Tensor):
                                dtype = t.dtype
                                break

                    # dtype still not found, prefer using the default dtype
                    # from the attr.
                    if dtype is None and input_arg.type_attr in default_type_attr_map:
                        default_dtype = default_type_attr_map[
                            input_arg.type_attr]

                try:
                    if not input_arg.is_ref and dtype:
                        dtype = dtypes.as_dtype(dtype).base_dtype
                    values = ops.internal_convert_n_to_tensor(
                        values,
                        name=input_arg.name,
                        dtype=dtype if dtype else None,
                        preferred_dtype=default_dtype,
                        as_ref=input_arg.is_ref)
                    if input_arg.number_attr and len(
                            set(v.dtype.base_dtype for v in values)) > 1:
                        raise TypeError()  # All types should match.
                except (TypeError, ValueError):
                    # What types does the conversion function think values have?
                    observed_types = []
                    for value in values:
                        try:
                            converted_value = ops.internal_convert_to_tensor(
                                value, as_ref=input_arg.is_ref)
                            observed_types.append(
                                converted_value.dtype.base_dtype.name)
                        except (TypeError, ValueError):
                            observed_types.append(
                                "<NOT CONVERTIBLE TO TENSOR>")
                    observed = ", ".join(observed_types)

                    prefix = (
                        "Tensors in list passed to '%s' of '%s' Op have types [%s]"
                        % (input_name, op_type_name, observed))
                    if input_arg.number_attr:
                        if input_arg.type != types_pb2.DT_INVALID:
                            raise TypeError(
                                "%s that do not match expected type %s." %
                                (prefix, dtype.name))
                        elif input_arg.type_attr in attrs:
                            raise TypeError(
                                "%s that do not match type %s inferred from "
                                "earlier arguments." % (prefix, dtype.name))
                        else:
                            raise TypeError("%s that don't all match." %
                                            prefix)
                    else:
                        raise TypeError("%s that are invalid. Tensors: %s" %
                                        (prefix, values))

                types = [x.dtype for x in values]
                inputs.extend(values)
            else:
                # In cases where we have an expected type, try to convert non-Tensor
                # arguments to that type.
                dtype = None
                default_dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.type_attr in attrs:
                    dtype = attrs[input_arg.type_attr]
                elif input_arg.type_attr in default_type_attr_map:
                    # The dtype could not be inferred solely from the inputs,
                    # so we fall back to the attr's default; this keeps code that
                    # adds a new attr with a default backwards compatible.
                    default_dtype = default_type_attr_map[input_arg.type_attr]

                try:
                    values = ops.internal_convert_to_tensor(
                        values,
                        name=input_arg.name,
                        dtype=dtype,
                        as_ref=input_arg.is_ref,
                        preferred_dtype=default_dtype)
                except TypeError as err:
                    if dtype is None:
                        raise err
                    else:
                        raise TypeError(
                            "Expected %s passed to parameter '%s' of op '%s', got %s of "
                            "type '%s' instead. Error: %s" %
                            (dtypes.as_dtype(dtype).name,
                             input_arg.name, op_type_name, repr(values),
                             type(values).__name__, err))
                except ValueError:
                    # What type does convert_to_tensor think it has?
                    try:
                        observed = ops.internal_convert_to_tensor(
                            values, as_ref=input_arg.is_ref).dtype.name
                    except ValueError as err:
                        raise ValueError(
                            "Tried to convert '%s' to a tensor and failed. Error: %s"
                            % (input_name, err))
                    prefix = (
                        "Input '%s' of '%s' Op has type %s that does not match"
                        % (input_name, op_type_name, observed))
                    if input_arg.type != types_pb2.DT_INVALID:
                        raise TypeError(
                            "%s expected type of %s." %
                            (prefix, dtypes.as_dtype(input_arg.type).name))
                    else:
                        # Update the maps with the default, if needed.
                        k = input_arg.type_attr
                        if k in default_type_attr_map:
                            if k not in attrs:
                                attrs[k] = default_type_attr_map[k]
                                if k not in inferred_from:
                                    inferred_from[k] = "Default in OpDef"

                        raise TypeError(
                            "%s type %s of argument '%s'." %
                            (prefix, dtypes.as_dtype(
                                attrs[input_arg.type_attr]).name,
                             inferred_from[input_arg.type_attr]))

                types = [values.dtype]
                inputs.append(values)
            base_types = [x.base_dtype for x in types]

            if input_arg.number_attr:
                # <number-attr> * <type> or <number-attr> * <type-attr>
                if input_arg.number_attr in attrs:
                    if len(values) != attrs[input_arg.number_attr]:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d must match "
                            "length %d of argument '%s'." %
                            (input_name, op_type_name, len(values),
                             attrs[input_arg.number_attr],
                             inferred_from[input_arg.number_attr]))
                else:
                    attrs[input_arg.number_attr] = len(values)
                    inferred_from[input_arg.number_attr] = input_name
                    num_attr = _Attr(op_def, input_arg.number_attr)
                    if num_attr.has_minimum and len(values) < num_attr.minimum:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d shorter "
                            "than minimum length %d." %
                            (input_name, op_type_name, len(values),
                             num_attr.minimum))
                # All tensors must have the same base type.
                if any(bt != base_types[0] for bt in base_types):
                    raise TypeError("All tensors passed to '%s' of '%s' Op "
                                    "must have the same type." %
                                    (input_name, op_type_name))
                if input_arg.type != types_pb2.DT_INVALID:
                    # <number-attr> * <type> case
                    if base_types and base_types[0] != input_arg.type:
                        assert False, "Unreachable"
                elif input_arg.type_attr in attrs:
                    # <number-attr> * <type-attr> case, where <type-attr> already
                    # has an inferred value.
                    if base_types and base_types[0] != attrs[
                            input_arg.type_attr]:
                        assert False, "Unreachable"
                else:
                    # <number-attr> * <type-attr> case, where we are now setting
                    # the <type-attr> based on this input
                    if not base_types:
                        raise TypeError(
                            "Don't know how to infer type variable from empty input "
                            "list passed to input '%s' of '%s' Op." %
                            (input_name, op_type_name))
                    attrs[input_arg.type_attr] = base_types[0]
                    inferred_from[input_arg.type_attr] = input_name
                    type_attr = _Attr(op_def, input_arg.type_attr)
                    _SatisfiesTypeConstraint(base_types[0],
                                             type_attr,
                                             param_name=input_name)
            elif input_arg.type_attr:
                # <type-attr>
                attr_value = base_types[0]
                if input_arg.type_attr in attrs:
                    if attrs[input_arg.type_attr] != attr_value:
                        raise TypeError(
                            "Input '%s' of '%s' Op has type %s that does not "
                            "match type %s of argument '%s'." %
                            (input_name, op_type_name,
                             dtypes.as_dtype(attr_value).name,
                             dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                             inferred_from[input_arg.type_attr]))
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(base_type,
                                                 _Attr(op_def,
                                                       input_arg.type_attr),
                                                 param_name=input_name)
                    attrs[input_arg.type_attr] = attr_value
                    inferred_from[input_arg.type_attr] = input_name
            elif input_arg.type_list_attr:
                # <type-list-attr>
                attr_value = base_types
                if input_arg.type_list_attr in attrs:
                    if attrs[input_arg.type_list_attr] != attr_value:
                        raise TypeError(
                            "Input '%s' of '%s' Op has type list of %s that does not "
                            "match type list %s of argument '%s'." %
                            (input_name, op_type_name, ", ".join(
                                dtypes.as_dtype(x).name
                                for x in attr_value), ", ".join(
                                    dtypes.as_dtype(x).name
                                    for x in attrs[input_arg.type_list_attr]),
                             inferred_from[input_arg.type_list_attr]))
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(base_type,
                                                 _Attr(
                                                     op_def,
                                                     input_arg.type_list_attr),
                                                 param_name=input_name)
                    attrs[input_arg.type_list_attr] = attr_value
                    inferred_from[input_arg.type_list_attr] = input_name
            else:
                # single Tensor with specified type
                if base_types[0] != input_arg.type:
                    assert False, "Unreachable"

            if input_arg.is_ref:
                if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
                    raise TypeError((
                        "'%s' Op requires that input '%s' be a mutable tensor "
                        "(e.g.: a tf.Variable)") % (op_type_name, input_name))
                input_types.extend(types)
            else:
                input_types.extend(base_types)

        # Process remaining attrs
        for attr in op_def.attr:
            # Skip attrs that have already had their values inferred
            if attr.name in attrs:
                if attr.name in keywords:
                    raise TypeError(
                        "Should not specify value for inferred attr '%s'." %
                        attr.name)
                continue
            if attr.name in keywords:
                attrs[attr.name] = keywords.pop(attr.name)
            elif attr.name + "_" in keywords:
                # Attrs whose names match Python keywords have an extra '_'
                # appended, so we must check for that as well.
                attrs[attr.name] = keywords.pop(attr.name + "_")
            else:
                raise TypeError("No argument for attr " + attr.name)

        # Convert attr values to AttrValue protos.
        attr_protos = {}
        for attr_def in op_def.attr:
            key = attr_def.name
            value = attrs[key]
            attr_value = attr_value_pb2.AttrValue()
            if attr_def.HasField("default_value") and value is None:
                attr_value.CopyFrom(attr_def.default_value)
                attr_protos[key] = attr_value
                continue
            if attr_def.type.startswith("list("):
                if not _IsListValue(value):
                    raise TypeError("Expected list for attr " + key)
                if attr_def.has_minimum:
                    if len(value) < attr_def.minimum:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed list of length %d "
                            "less than minimum %d." %
                            (key, op_type_name, len(value), attr_def.minimum))
                attr_value.list.SetInParent()
            if attr_def.type == "string":
                attr_value.s = _MakeStr(value, key)
                if attr_def.HasField("allowed_values"):
                    if attr_value.s not in attr_def.allowed_values.list.s:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                            % (key, op_type_name, compat.as_text(
                                attr_value.s), '", "'.join(
                                    map(compat.as_text,
                                        attr_def.allowed_values.list.s))))
            elif attr_def.type == "list(string)":
                attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                if attr_def.HasField("allowed_values"):
                    for x in attr_value.list.s:
                        if x not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"."
                                % (key, op_type_name, compat.as_text(x),
                                   '", "'.join(
                                       map(compat.as_text,
                                           attr_def.allowed_values.list.s))))
            elif attr_def.type == "int":
                attr_value.i = _MakeInt(value, key)
                if attr_def.has_minimum:
                    if attr_value.i < attr_def.minimum:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed %d less than minimum %d."
                            % (key, op_type_name, attr_value.i,
                               attr_def.minimum))
            elif attr_def.type == "list(int)":
                attr_value.list.i.extend([_MakeInt(x, key) for x in value])
            elif attr_def.type == "float":
                attr_value.f = _MakeFloat(value, key)
            elif attr_def.type == "list(float)":
                attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
            elif attr_def.type == "bool":
                attr_value.b = _MakeBool(value, key)
            elif attr_def.type == "list(bool)":
                attr_value.list.b.extend([_MakeBool(x, key) for x in value])
            elif attr_def.type == "type":
                attr_value.type = _MakeType(value, attr_def)
            elif attr_def.type == "list(type)":
                attr_value.list.type.extend(
                    [_MakeType(x, attr_def) for x in value])
            elif attr_def.type == "shape":
                attr_value.shape.CopyFrom(_MakeShape(value, key))
            elif attr_def.type == "list(shape)":
                attr_value.list.shape.extend(
                    [_MakeShape(x, key) for x in value])
            elif attr_def.type == "tensor":
                attr_value.tensor.CopyFrom(_MakeTensor(value, key))
            elif attr_def.type == "list(tensor)":
                attr_value.list.tensor.extend(
                    [_MakeTensor(x, key) for x in value])
            elif attr_def.type == "func":
                attr_value.func.CopyFrom(_MakeFunc(value, key))
            elif attr_def.type == "list(func)":
                attr_value.list.func.extend([_MakeFunc(x, key) for x in value])
            else:
                raise TypeError("Unrecognized Attr type " + attr_def.type)

            attr_protos[key] = attr_value
        del attrs  # attrs is no longer authoritative, use attr_protos instead

        # Determine output types (possibly using attrs)
        output_structure = []
        for arg in op_def.output_arg:
            if arg.number_attr:
                n = _AttrValue(attr_protos, arg.number_attr).i
                output_structure.append(n)
            elif arg.type_attr:
                t = _AttrValue(attr_protos, arg.type_attr)
                output_structure.append(None)
            elif arg.type_list_attr:
                t = _AttrValue(attr_protos, arg.type_list_attr)
                output_structure.append(len(t.list.type))
            else:
                output_structure.append(None)

        if keywords:
            raise TypeError("apply_op() got unexpected keyword arguments: " +
                            ", ".join(sorted(keywords.keys())))

        # NOTE(mrry): We add an explicit colocation constraint between
        # the newly created op and any of its reference-typed inputs.
        must_colocate_inputs = [
            val for arg, val in zip(op_def.input_arg, inputs) if arg.is_ref
        ]
        with _MaybeColocateWith(must_colocate_inputs):
            # Add Op to graph
            # pylint: disable=protected-access
            op = g._create_op_internal(op_type_name,
                                       inputs,
                                       dtypes=None,
                                       name=scope,
                                       input_types=input_types,
                                       attrs=attr_protos,
                                       op_def=op_def)

        # `outputs` is returned as a separate return value so that the output
        # tensors and the `op` itself can be decoupled, which lets the
        # `op_callbacks` function properly. See framework/op_callbacks.py
        # for more details.
        outputs = op.outputs
        # Conditionally invoke tfdbg v2's op callback(s).
        if op_callbacks.should_invoke_op_callbacks():
            callback_outputs = op_callbacks.invoke_op_callbacks(
                op.node_def.op,
                tuple(op.inputs),
                attr_protos,
                tuple(outputs),
                op_name=op.name,
                graph=g)
            if callback_outputs is not None:
                outputs = callback_outputs

        return output_structure, op_def.is_stateful, op, outputs
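
Generated op wrappers (the gen_*_ops modules) are the usual callers of _apply_op_helper. A hedged sketch of a direct call in graph mode; the op_def_library module path is TF-internal and may move between versions:

import tensorflow as tf
from tensorflow.python.framework import op_def_library

g = tf.Graph()
with g.as_default():
    x = tf.constant([1.0, 2.0])
    y = tf.constant([3.0, 4.0])
    # "AddV2" declares one type attr T; it is inferred here from the inputs.
    output_structure, is_stateful, op, outputs = op_def_library._apply_op_helper(
        "AddV2", x=x, y=y, name="my_add")
    print(op.type, [out.dtype for out in outputs])  # AddV2 [tf.float32]
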
Example #5
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op that returns output_structure, op."""
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e.message}")

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          f"Op {op_type_name} is not available in GraphDef version {producer}. "
          f"It has been removed in version {deprecation_version}. "
          f"{op_def.deprecation.explanation}.")

  # Fill in the list of default types for all "type" attrs.  This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  allowed_list_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[key] = attr_def.allowed_values.list.type

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a Python keyword or
        # built-in, so the name + "_" is used instead.
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError(f"No argument for input {input_name} found in {op_def}")

      # Goals:
      # * Convert values to Tensors if they contain constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              f"Expected list for '{input_name}' argument to '{op_type_name}' "
              f"Op, not {values}.")
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.internal_convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          all_types = set(v.dtype.base_dtype for v in values)
          if input_arg.number_attr and len(all_types) > 1:
            # All types should match.
            raise TypeError(f"Not all types matched for {input_arg.name} for "
                            f"{op_type_name}. Got {all_types}")
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError(f"{prefix} that do not match expected type "
                              f"{dtype.name}.")
            elif input_arg.type_attr in attrs:
              raise TypeError(f"{prefix} that do not match type {dtype.name} "
                              "inferred from earlier arguments.")
            else:
              raise TypeError(f"{prefix} that don't all match.")
          else:
            raise TypeError(f"{prefix} that are invalid. Tensors: {values}")

        types = [x.dtype for x in values]
        inputs.extend(values)
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        allowed_list = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we fall back to the attr's default; this keeps code that
          # adds a new attr with a default backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]
          allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

        try:
          # First see if we can get a valid dtype with the default conversion
          # and whether it matches one of the allowed dtypes. Some ops like
          # ConcatV2 may not list allowed dtypes, in which case we skip this.
          if dtype is None and allowed_list:
            inferred = None
            try:
              inferred = ops.convert_to_tensor(
                  values, name=input_arg.name, as_ref=input_arg.is_ref)
            except TypeError:
              # When converting a python object such as a list of Dimensions, we
              # need a dtype to be specified, thus tensor conversion may throw
              # an exception which we will ignore and try again below.
              pass

            # If we did not match an allowed dtype, try again with the default
            # dtype. This could be because we have an empty tensor and thus we
            # picked the wrong type.
            if inferred is not None and inferred.dtype in allowed_list:
              values = inferred
            else:
              values = ops.convert_to_tensor(
                  values,
                  name=input_arg.name,
                  as_ref=input_arg.is_ref,
                  preferred_dtype=default_dtype)
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        except TypeError as err:
          if dtype is None:
            raise err
          else:
            raise TypeError(
                f"Expected {dtypes.as_dtype(dtype).name} passed to parameter "
                f"'{input_arg.name}' of op '{op_type_name}', got "
                f"{repr(values)} of type '{type(values).__name__}' instead. "
                f"Error: {err}")
        except ValueError:
          # What type does convert_to_tensor think it has?
          try:
            observed = ops.convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
          except ValueError as err:
            raise ValueError(
                f"Tried to convert '{input_name}' to a tensor and failed. "
                f"Error: {err}")
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError(f"{prefix} expected type of "
                            f"{dtypes.as_dtype(input_arg.type).name}.")
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                f"{prefix} type "
                f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
                f"argument '{inferred_from[input_arg.type_attr]}'.")

        types = [values.dtype]
        inputs.append(values)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                f"List argument '{input_name}' to '{op_type_name}' Op with "
                f"length {len(values)} must match length "
                f"{attrs[input_arg.number_attr]} of argument "
                f"'{inferred_from[input_arg.number_attr]}'.")
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                f"List argument '{input_name}' to '{op_type_name}' Op with "
                f"length {len(values)} shorter than minimum length "
                f"{num_attr.minimum}.")
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              f"All tensors passed to '{input_name}' of '{op_type_name}' Op "
              f"must have the same type. Got {base_types} instead.")
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            # If it's in default_type_attr_map, then wait to set it
            # (in "process remaining attrs", below).
            if input_arg.type_attr not in default_type_attr_map:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  f"list passed to input '{input_name}' of '{op_type_name}' "
                  "Op.")
          else:
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
                                     param_name=input_name)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            raise TypeError(
                f"Input '{input_name}' of '{op_type_name}' Op has type "
                f"{dtypes.as_dtype(attr_value).name} that does not match type "
                f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
                f"argument '{inferred_from[input_arg.type_attr]}'.")
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr),
                                     param_name=input_name)
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            actual_types = ", ".join(
                dtypes.as_dtype(x).name for x in attr_value)
            expected_types = ", ".join(
                dtypes.as_dtype(x).name
                for x in attrs[input_arg.type_list_attr])
            raise TypeError(
                f"Input '{input_name}' of '{op_type_name}' Op has type list of "
                f"{actual_types} that does not match type list {expected_types}"
                f" of argument '{inferred_from[input_arg.type_list_attr]}'.")
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr),
                                     param_name=input_name)
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      if input_arg.is_ref:
        if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
          raise TypeError(
              f"'{op_type_name}' Op requires that input '{input_name}' be a "
              "mutable tensor (e.g.: a tf.Variable)")
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              f"Should not specify value for inferred attr '{attr.name}' for "
              f"{op_type_name}.")
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      elif attr.name in default_type_attr_map:
        attrs[attr.name] = default_type_attr_map[attr.name]
        inferred_from.setdefault(attr.name, "Default in OpDef")
      else:
        raise TypeError(f"No argument found for attr {attr.name} for "
                        f"{op_type_name}")

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]

      if attr_def.HasField("default_value") and value is None:
        attr_value = attr_value_pb2.AttrValue()
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue

      attr_value = value_to_attr_value(value, attr_def.type, key)
      if attr_def.type.startswith("list("):
        _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
      if attr_def.HasField("allowed_values"):
        if attr_def.type == "string":
          _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
                                             op_type_name)
        elif attr_def.type == "list(string)":
          for value in attr_value.list.s:
            _SatisfiesAllowedStringsConstraint(value, attr_def, key,
                                               op_type_name)
      if attr_def.has_minimum and attr_def.type == "int":
        _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
                                       op_type_name)
      if attr_def.type == "type":
        _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
      if attr_def.type == "list(type)":
        for value in attr_value.list.type:
          _SatisfiesTypeConstraint(value, attr_def, key)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_structure = []
    for arg in op_def.output_arg:
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr, op_type_name).i
        output_structure.append(n)
      elif arg.type_attr:
        t = _AttrValue(attr_protos, arg.type_attr, op_type_name)
        output_structure.append(None)
      elif arg.type_list_attr:
        t = _AttrValue(attr_protos, arg.type_list_attr, op_type_name)
        output_structure.append(len(t.list.type))
      else:
        output_structure.append(None)

    if keywords:
      all_keywords = ", ".join(sorted(keywords.keys()))
      raise TypeError(f"{op_type_name} got unexpected keyword arguments: "
                      f"{all_keywords}.")

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` itself can be decoupled, which lets the
    # `op_callbacks` function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs
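
One visible difference in this revision is the richer f-string diagnostics. A sketch of the mismatched-list error path in graph mode (TF 2.x assumed; the exact wording varies between versions):

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    try:
        # 'values' is an N * T list input, so all element dtypes must agree.
        tf.raw_ops.ConcatV2(
            values=[tf.constant([1]), tf.constant([1.0])], axis=0)
    except TypeError as err:
        print(err)  # Tensors in list passed to 'values' of 'ConcatV2' Op
                    # have types [int32, float32] that don't all match.
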
Example #6
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op that returns output_structure, op."""

  op_def, g, producer = _GetOpDef(op_type_name, keywords)
  name = name if name else op_type_name

  attrs, attr_protos = {}, {}
  default_type_attr_map, allowed_list_attr_map = {}, {}
  inputs, input_types, output_structure = [], [], []
  fallback = True

  if (_CanExtractAttrsFastPath(op_def, keywords) and
      flags.config().graph_building_optimization.value()):
    fallback = False
    attr_protos, inputs, input_types, output_structure = (
        op_def_library_pybind.process_inputs(op_type_name, producer, keywords))

  if fallback:
    _CheckOpDeprecation(op_type_name, op_def, producer)
    _ExtractDefaultTypesAndAllowedTypes(op_def, default_type_attr_map,
                                        allowed_list_attr_map)

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  with g.as_default(), ops.name_scope(name) as scope:
    if fallback:
      _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                             keywords, default_type_attr_map, attrs, inputs,
                             input_types)
      _ExtractRemainingAttrs(op_type_name, op_def, keywords,
                             default_type_attr_map, attrs)
      _ExtractAttrProto(op_type_name, op_def, attrs, attr_protos)
      del attrs  # attrs is no longer authoritative, use attr_protos instead
      _ExtractOutputStructure(op_type_name, op_def, attr_protos,
                              output_structure)
      _CheckAllInputsUsed(op_type_name, keywords)

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` itself can be decoupled, which lets the
    # `op_callbacks` function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs
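
The restructuring above follows a fast-path-with-fallback pattern: a C++ helper (op_def_library_pybind.process_inputs) handles the common case when the graph_building_optimization flag is on, and the legacy Python logic remains as the fallback. A minimal, self-contained sketch of the pattern with hypothetical stand-in helpers (none of these names are from TensorFlow):

def process_inputs_fast(op_type_name, keywords):
  # Hypothetical stand-in for the pybind fast path.
  return ("fast", op_type_name, sorted(keywords))

def process_inputs_slow(op_type_name, keywords):
  # Hypothetical stand-in for the legacy per-attr Python logic.
  return ("slow", op_type_name, sorted(keywords))

def can_use_fast_path(keywords):
  # Hypothetical predicate, standing in for _CanExtractAttrsFastPath:
  # only simple attribute values qualify here.
  return all(isinstance(v, (int, float, str, bool)) for v in keywords.values())

def apply_op(op_type_name, **keywords):
  if can_use_fast_path(keywords):
    return process_inputs_fast(op_type_name, keywords)
  return process_inputs_slow(op_type_name, keywords)

print(apply_op("Example", n=3))        # takes the fast path
print(apply_op("Example", fn=print))   # falls back to the slow path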