def _constant_impl(
    value, dtype, shape, name, verify_shape, allow_broadcast):
  """Builds a constant tensor, eagerly or as a `Const` graph node.

  Args:
    value: The constant value, forwarded to the eager implementation or to
      `tensor_util.make_tensor_proto`.
    dtype: Requested element type, forwarded unchanged.
    shape: Requested shape, forwarded unchanged.
    name: Name given to the `Const` op in graph mode.
    verify_shape: Forwarded to the eager implementation and to
      `make_tensor_proto`.
    allow_broadcast: Forwarded to `make_tensor_proto` (graph mode only).

  Returns:
    The result of `_constant_eager_impl` when executing eagerly; otherwise the
    first output of the new `Const` op, or the single tensor returned by the
    op callbacks when they supply a replacement.
  """
  ctx = context.context()
  if ctx.executing_eagerly():
    if not trace.enabled:
      return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
    # Same call as above, wrapped in a trace scope when tracing is on.
    with trace.Trace("tf.constant"):
      return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)

  graph = ops.get_default_graph()

  # Pack the value into an AttrValue holding a TensorProto.
  value_attr = attr_value_pb2.AttrValue()
  proto = tensor_util.make_tensor_proto(
      value,
      dtype=dtype,
      shape=shape,
      verify_shape=verify_shape,
      allow_broadcast=allow_broadcast)
  value_attr.tensor.CopyFrom(proto)

  # The op's dtype attr is taken from the proto that was just built.
  dtype_attr = attr_value_pb2.AttrValue(type=value_attr.tensor.dtype)
  attrs = {"value": value_attr, "dtype": dtype_attr}

  const_op = graph._create_op_internal(  # pylint: disable=protected-access
      "Const", [], [dtype_attr.type], attrs=attrs, name=name)
  const_tensor = const_op.outputs[0]

  if op_callbacks.should_invoke_op_callbacks():
    # TODO(b/147670703): Once the special-op creation code paths
    # are unified. Remove this `if` block.
    replacement = op_callbacks.invoke_op_callbacks(
        "Const", tuple(), attrs, (const_tensor,), op_name=name, graph=graph)
    if replacement is not None:
      (const_tensor,) = replacement
  return const_tensor
def graph_placeholder(dtype, shape, name=None):
  """Graph-only version of tf.compat.v1.placeholder(), for internal use only.

  Args:
    dtype: Element type; its `base_dtype` is what gets used.
    shape: A list/tuple (wrapped in a `TensorShape`) or an object exposing
      `as_proto()`.
    name: Optional name for the `Placeholder` op.

  Returns:
    The single output tensor of the new `Placeholder` op, or the single
    tensor returned by the op callbacks when they supply a replacement.
  """
  dtype = dtype.base_dtype
  dtype_attr = attr_value_pb2.AttrValue(type=dtype.as_datatype_enum)

  # Lists and tuples are normalized first; anything else must already
  # provide `as_proto()`.
  if isinstance(shape, (list, tuple)):
    shape = tensor_shape.TensorShape(shape)
  shape_attr = attr_value_pb2.AttrValue(shape=shape.as_proto())

  graph = ops.get_default_graph()
  attrs = {"dtype": dtype_attr, "shape": shape_attr}
  placeholder_op = graph._create_op_internal(  # pylint: disable=protected-access
      "Placeholder", [], [dtype], input_types=[], attrs=attrs, name=name)
  (result,) = placeholder_op.outputs

  if not op_callbacks.should_invoke_op_callbacks():
    return result

  # TODO(b/147670703): Once the special-op creation code paths
  # are unified. Remove this `if` block.
  replacement = op_callbacks.invoke_op_callbacks(
      "Placeholder", tuple(), attrs, tuple(placeholder_op.outputs),
      op_name=name, graph=graph)
  if replacement is not None:
    (result,) = replacement
  return result
def _constant_impl(value, dtype, shape, name, verify_shape, allow_broadcast):
  """Builds a constant tensor, eagerly or as a `Const` graph node.

  In eager mode the value is converted to an eager tensor and then, if a
  different `shape` was requested, either reshaped (same element count) or
  filled (single element). In graph mode a `Const` op is created in the
  default graph.

  Args:
    value: The constant value.
    dtype: Requested element type.
    shape: Requested shape, or None to keep the converted value's shape.
    name: Name given to the `Const` op in graph mode.
    verify_shape: If true, a shape mismatch in eager mode raises instead of
      being reshaped/filled; also forwarded to `make_tensor_proto`.
    allow_broadcast: Forwarded to `make_tensor_proto` (graph mode only).

  Returns:
    The constant tensor.

  Raises:
    TypeError: In eager mode, if `verify_shape` is set and the shapes differ,
      or if the value can neither be reshaped nor filled into `shape`.
  """
  ctx = context.context()
  if ctx.executing_eagerly():
    eager_t = convert_to_eager_tensor(value, ctx, dtype)
    if shape is None:
      return eager_t
    shape = tensor_shape.as_shape(shape)
    if shape == eager_t.shape:
      return eager_t
    if verify_shape:
      raise TypeError("Expected Tensor's shape: %s, got %s." %
                      (tuple(shape), tuple(eager_t.shape)))
    value_count = eager_t.shape.num_elements()
    # TODO(josh11b): Implement shape -> eager tensor conversion.
    if value_count == shape.num_elements():
      return _eager_reshape(eager_t, shape.as_list(), ctx)
    if value_count != 1:
      raise TypeError(
          "Eager execution of tf.constant with unsupported shape "
          "(value has %d elements, shape is %s with %d elements)." %
          (value_count, shape, shape.num_elements()))
    if eager_t.dtype != dtypes.bool:
      return _eager_fill(shape.as_list(), eager_t, ctx)
    # We don't have a Fill kernel for bool dtype on GPU. So we first run
    # Fill on CPU and then copy to GPU if needed.
    with ops.device("/device:CPU:0"):
      filled = _eager_fill(
          shape.as_list(), _eager_identity(eager_t, ctx), ctx)
    return _eager_identity(filled, ctx)

  graph = ops.get_default_graph()

  # Pack the value into an AttrValue holding a TensorProto.
  value_attr = attr_value_pb2.AttrValue()
  proto = tensor_util.make_tensor_proto(
      value,
      dtype=dtype,
      shape=shape,
      verify_shape=verify_shape,
      allow_broadcast=allow_broadcast)
  value_attr.tensor.CopyFrom(proto)

  # The op's dtype attr is taken from the proto that was just built.
  dtype_attr = attr_value_pb2.AttrValue(type=value_attr.tensor.dtype)
  attrs = {"value": value_attr, "dtype": dtype_attr}

  const_op = graph._create_op_internal(  # pylint: disable=protected-access
      "Const", [], [dtype_attr.type], attrs=attrs, name=name)
  const_tensor = const_op.outputs[0]

  if op_callbacks.should_invoke_op_callbacks():
    # TODO(b/147670703): Once the special-op creation code paths
    # are unified. Remove this `if` block.
    replacement = op_callbacks.invoke_op_callbacks(
        "Const", tuple(), attrs, (const_tensor,), op_name=name, graph=graph)
    if replacement is not None:
      (const_tensor,) = replacement
  return const_tensor
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Creates an op of type `op_type_name` in the graph of its inputs.

  Looks up the OpDef, converts the input keywords to tensors while inferring
  the op's type/number attrs from them, converts the remaining keywords to
  AttrValue protos, creates the op, and optionally runs the op callbacks.

  Args:
    op_type_name: Registered name of the op type to create.
    name: Optional name scope for the op; defaults to `op_type_name`.
    **keywords: One entry per input and per non-inferred attr of the OpDef.
      A trailing underscore is accepted on names (e.g. `input_`) so Python
      keywords/built-ins can be used.

  Returns:
    A tuple `(output_structure, is_stateful, op, outputs)`, where
    `output_structure` has one entry per output arg (an int length for list
    outputs, None for single-tensor outputs), `is_stateful` is
    `op_def.is_stateful`, `op` is the created op, and `outputs` is either
    `op.outputs` or the replacement returned by the op callbacks.

  Raises:
    RuntimeError: If the op type is not registered or the graph cannot be
      determined from the inputs.
    NotImplementedError: If the op is deprecated at the graph's producer
      version.
    TypeError: On missing/unexpected keywords or input/attr type mismatches.
    ValueError: On list-length or attr-range violations, or failed tensor
      conversion.
  """
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError("Unrecognized Op name " + op_type_name)

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    # Python 3 exceptions have no `.message` attribute; format the exception
    # itself so this handler cannot fail with AttributeError.
    raise RuntimeError(
        "Cannot determine graph for Op '%s' due to: %s"
        % (op_type_name, e)) from e

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          ("Op %s is not available in GraphDef version %d. "
           "It has been removed in version %d. %s.")
          % (op_type_name, producer, deprecation_version,
             op_def.deprecation.explanation))

  # Fill in the list of default types for all "type" attrs.  This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a keyword or built-in
        # for Python so we use the name + _ instead.
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError("No argument for input " + input_name)

      # Goals:
      # * Convert values to Tensors if it contains constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              "Expected list for '%s' argument to '%s' Op, not %s." %
              (input_name, op_type_name, values))
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.internal_convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          if input_arg.number_attr and len(
              set(v.dtype.base_dtype for v in values)) > 1:
            # Deliberately caught just below to build the detailed message.
            raise TypeError()  # All types should match.
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.internal_convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, dtype.name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, dtype.name))
            else:
              raise TypeError("%s that don't all match." % prefix)
          else:
            raise TypeError(
                "%s that are invalid. Tensors: %s" % (prefix, values))

        types = [x.dtype for x in values]
        inputs.extend(values)
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default, so code that adds a new attr
          # with a default is backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          values = ops.internal_convert_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype,
              as_ref=input_arg.is_ref,
              preferred_dtype=default_dtype)
        except TypeError as err:
          if dtype is None:
            # No expected type to report; propagate the original error.
            raise
          else:
            raise TypeError(
                "Expected %s passed to parameter '%s' of op '%s', got %s of "
                "type '%s' instead. Error: %s" %
                (dtypes.as_dtype(dtype).name, input_arg.name, op_type_name,
                 repr(values), type(values).__name__, err))
        except ValueError:
          # What type does convert_to_tensor think it has?
          try:
            observed = ops.internal_convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
          except ValueError as err:
            raise ValueError(
                "Tried to convert '%s' to a tensor and failed. Error: %s" %
                (input_name, err))
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError("%s expected type of %s." %
                            (prefix, dtypes.as_dtype(input_arg.type).name))
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                "%s type %s of argument '%s'." %
                (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))

        types = [values.dtype]
        inputs.append(values)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d must match "
                "length %d of argument '%s'." %
                (input_name, op_type_name, len(values),
                 attrs[input_arg.number_attr],
                 inferred_from[input_arg.number_attr]))
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                "List argument '%s' to '%s' Op with length %d shorter "
                "than minimum length %d." %
                (input_name, op_type_name, len(values), num_attr.minimum))
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              "All tensors passed to '%s' of '%s' Op "
              "must have the same type." %
              (input_name, op_type_name))
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            raise TypeError(
                "Don't know how to infer type variable from empty input "
                "list passed to input '%s' of '%s' Op." %
                (input_name, op_type_name))
          attrs[input_arg.type_attr] = base_types[0]
          inferred_from[input_arg.type_attr] = input_name
          type_attr = _Attr(op_def, input_arg.type_attr)
          _SatisfiesTypeConstraint(base_types[0], type_attr,
                                   param_name=input_name)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type %s that does not "
                "match type %s of argument '%s'." %
                (input_name, op_type_name, dtypes.as_dtype(attr_value).name,
                 dtypes.as_dtype(attrs[input_arg.type_attr]).name,
                 inferred_from[input_arg.type_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr),
                                     param_name=input_name)
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            raise TypeError(
                "Input '%s' of '%s' Op has type list of %s that does not "
                "match type list %s of argument '%s'." %
                (input_name, op_type_name,
                 ", ".join(dtypes.as_dtype(x).name for x in attr_value),
                 ", ".join(dtypes.as_dtype(x).name
                           for x in attrs[input_arg.type_list_attr]),
                 inferred_from[input_arg.type_list_attr]))
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr),
                                     param_name=input_name)
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      # Ref inputs keep their ref dtypes; everything else records base types.
      if input_arg.is_ref:
        if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
          raise TypeError(
              ("'%s' Op requires that input '%s' be a mutable tensor "
               "(e.g.: a tf.Variable)") % (op_type_name, input_name))
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              "Should not specify value for inferred attr '%s'." % attr.name)
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      else:
        raise TypeError("No argument for attr " + attr.name)

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]
      attr_value = attr_value_pb2.AttrValue()
      if attr_def.HasField("default_value") and value is None:
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue
      if attr_def.type.startswith("list("):
        if not _IsListValue(value):
          raise TypeError("Expected list for attr " + key)
        if attr_def.has_minimum:
          if len(value) < attr_def.minimum:
            raise ValueError(
                "Attr '%s' of '%s' Op passed list of length %d "
                "less than minimum %d." %
                (key, op_type_name, len(value), attr_def.minimum))
        attr_value.list.SetInParent()
      if attr_def.type == "string":
        attr_value.s = _MakeStr(value, key)
        if attr_def.HasField("allowed_values"):
          if attr_value.s not in attr_def.allowed_values.list.s:
            raise ValueError(
                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                (key, op_type_name, compat.as_text(attr_value.s),
                 '", "'.join(map(compat.as_text,
                                 attr_def.allowed_values.list.s))))
      elif attr_def.type == "list(string)":
        attr_value.list.s.extend([_MakeStr(x, key) for x in value])
        if attr_def.HasField("allowed_values"):
          for x in attr_value.list.s:
            if x not in attr_def.allowed_values.list.s:
              raise ValueError(
                  "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                  (key, op_type_name, compat.as_text(x),
                   '", "'.join(map(compat.as_text,
                                   attr_def.allowed_values.list.s))))
      elif attr_def.type == "int":
        attr_value.i = _MakeInt(value, key)
        if attr_def.has_minimum:
          if attr_value.i < attr_def.minimum:
            raise ValueError(
                "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                (key, op_type_name, attr_value.i, attr_def.minimum))
      elif attr_def.type == "list(int)":
        attr_value.list.i.extend([_MakeInt(x, key) for x in value])
      elif attr_def.type == "float":
        attr_value.f = _MakeFloat(value, key)
      elif attr_def.type == "list(float)":
        attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
      elif attr_def.type == "bool":
        attr_value.b = _MakeBool(value, key)
      elif attr_def.type == "list(bool)":
        attr_value.list.b.extend([_MakeBool(x, key) for x in value])
      elif attr_def.type == "type":
        attr_value.type = _MakeType(value, attr_def)
      elif attr_def.type == "list(type)":
        attr_value.list.type.extend(
            [_MakeType(x, attr_def) for x in value])
      elif attr_def.type == "shape":
        attr_value.shape.CopyFrom(_MakeShape(value, key))
      elif attr_def.type == "list(shape)":
        attr_value.list.shape.extend(
            [_MakeShape(x, key) for x in value])
      elif attr_def.type == "tensor":
        attr_value.tensor.CopyFrom(_MakeTensor(value, key))
      elif attr_def.type == "list(tensor)":
        attr_value.list.tensor.extend(
            [_MakeTensor(x, key) for x in value])
      elif attr_def.type == "func":
        attr_value.func.CopyFrom(_MakeFunc(value, key))
      elif attr_def.type == "list(func)":
        attr_value.list.func.extend([_MakeFunc(x, key) for x in value])
      else:
        raise TypeError("Unrecognized Attr type " + attr_def.type)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_structure = []
    for arg in op_def.output_arg:
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr).i
        output_structure.append(n)
      elif arg.type_attr:
        # Result unused; the lookup is kept as in the original
        # (presumably it validates that the attr exists -- TODO confirm).
        _AttrValue(attr_protos, arg.type_attr)
        output_structure.append(None)
      elif arg.type_list_attr:
        type_list = _AttrValue(attr_protos, arg.type_list_attr)
        output_structure.append(len(type_list.list.type))
      else:
        output_structure.append(None)

    # Anything still left in `keywords` matched neither an input nor an attr.
    if keywords:
      raise TypeError("apply_op() got unexpected keyword arguments: " +
                      ", ".join(sorted(keywords.keys())))

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` per se can be decoupled, so that the
    # `op_callbacks` can function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Creates an op of type `op_type_name` in the graph of its inputs.

  Looks up the OpDef, converts the input keywords to tensors while inferring
  the op's type/number attrs from them, converts the remaining keywords to
  AttrValue protos (validating their constraints), creates the op, and
  optionally runs the op callbacks.

  Args:
    op_type_name: Registered name of the op type to create.
    name: Optional name scope for the op; defaults to `op_type_name`.
    **keywords: One entry per input and per non-inferred attr of the OpDef.
      A trailing underscore is accepted on names (e.g. `input_`) so Python
      keywords/built-ins can be used. "type" attrs with a default in the
      OpDef may be omitted.

  Returns:
    A tuple `(output_structure, is_stateful, op, outputs)`, where
    `output_structure` has one entry per output arg (an int length for list
    outputs, None for single-tensor outputs), `is_stateful` is
    `op_def.is_stateful`, `op` is the created op, and `outputs` is either
    `op.outputs` or the replacement returned by the op callbacks.

  Raises:
    RuntimeError: If the op type is not registered or the graph cannot be
      determined from the inputs.
    NotImplementedError: If the op is deprecated at the graph's producer
      version.
    TypeError: On missing/unexpected keywords or input/attr type mismatches.
    ValueError: On list-length mismatches or failed tensor conversion.
  """
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    # pylint: enable=protected-access
  except AssertionError as e:
    # Python 3 exceptions have no `.message` attribute; format the exception
    # itself so this handler cannot fail with AttributeError.
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e}") from e

  # Default name if not specified.
  if name is None:
    name = op_type_name

  # Check for deprecation
  deprecation_version = op_def.deprecation.version
  if deprecation_version:
    producer = g.graph_def_versions.producer
    if producer >= deprecation_version:
      raise NotImplementedError(
          f"Op {op_type_name} is not available in GraphDef version {producer}. "
          f"It has been removed in version {deprecation_version}. "
          f"{op_def.deprecation.explanation}.")

  # Fill in the list of default types for all "type" attrs.  This
  # will be used to choose a preferred dtype to convert to in the
  # absence of input type information.
  #
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  default_type_attr_map = {}
  allowed_list_attr_map = {}
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[key] = attr_def.allowed_values.list.type

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  attrs = {}
  inputs = []
  input_types = []
  with g.as_default(), ops.name_scope(name) as scope:

    # Perform input type inference
    inferred_from = {}
    for input_arg in op_def.input_arg:
      input_name = input_arg.name
      if input_name in keywords:
        values = keywords.pop(input_name)
      elif input_name + "_" in keywords:
        # Handle the case where the name is a keyword or built-in
        # for Python so we use the name + _ instead.
        input_name += "_"
        values = keywords.pop(input_name)
      else:
        raise TypeError(f"No argument for input {input_name} found in {op_def}")

      # Goals:
      # * Convert values to Tensors if it contains constants.
      # * Verify that values is a list if that matches the input_arg's
      #   type.
      # * If the input_arg's type is determined by attrs, either set
      #   those attrs and validate those attr values are legal (if
      #   they have not yet been set) or validate the input matches
      #   the type indicated by the attrs (if they have already been
      #   inferred via an earlier input).
      # * If the input_arg has an explicit type, make sure the input
      #   conforms.

      if _IsListParameter(input_arg):
        if not _IsListValue(values):
          raise TypeError(
              f"Expected list for '{input_name}' argument to '{op_type_name}' "
              f"Op, not {values}.")
        # In cases where we expect all elements of the list to have the
        # same dtype, try to cast non-Tensor elements to that type.
        dtype = None
        default_dtype = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.number_attr:
          if input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]
          else:
            for t in values:
              if isinstance(t, ops.Tensor):
                dtype = t.dtype
                break

          # dtype still not found, prefer using the default dtype
          # from the attr.
          if dtype is None and input_arg.type_attr in default_type_attr_map:
            default_dtype = default_type_attr_map[input_arg.type_attr]

        try:
          if not input_arg.is_ref and dtype:
            dtype = dtypes.as_dtype(dtype).base_dtype
          values = ops.internal_convert_n_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype if dtype else None,
              preferred_dtype=default_dtype,
              as_ref=input_arg.is_ref)
          all_types = set(v.dtype.base_dtype for v in values)
          if input_arg.number_attr and len(all_types) > 1:
            # All types should match. Deliberately caught just below, where
            # the detailed message is built.
            raise TypeError(f"Not all types matched for {input_arg.name} for "
                            f"{op_type_name}. Got {all_types}")
        except (TypeError, ValueError):
          # What types does the conversion function think values have?
          observed_types = []
          for value in values:
            try:
              converted_value = ops.convert_to_tensor(
                  value, as_ref=input_arg.is_ref)
              observed_types.append(converted_value.dtype.base_dtype.name)
            except (TypeError, ValueError):
              observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
          observed = ", ".join(observed_types)

          prefix = (
              "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
              (input_name, op_type_name, observed))
          if input_arg.number_attr:
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError(f"{prefix} that do not match expected type "
                              f"{dtype.name}.")
            elif input_arg.type_attr in attrs:
              raise TypeError(f"{prefix} that do not match type {dtype.name} "
                              "inferred from earlier arguments.")
            else:
              raise TypeError(f"{prefix} that don't all match.")
          else:
            raise TypeError(f"{prefix} that are invalid. Tensors: {values}")

        types = [x.dtype for x in values]
        inputs.extend(values)
      else:
        # In cases where we have an expected type, try to convert non-Tensor
        # arguments to that type.
        dtype = None
        default_dtype = None
        allowed_list = None
        if input_arg.type != types_pb2.DT_INVALID:
          dtype = input_arg.type
        elif input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        elif input_arg.type_attr in default_type_attr_map:
          # The dtype could not be inferred solely from the inputs,
          # so we prefer the attr's default, so code that adds a new attr
          # with a default is backwards compatible.
          default_dtype = default_type_attr_map[input_arg.type_attr]
          allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

        try:
          # First see if we can get a valid dtype with the default conversion
          # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
          # not list allowed dtypes, in which case we should skip this.
          if dtype is None and allowed_list:
            inferred = None
            try:
              inferred = ops.convert_to_tensor(
                  values, name=input_arg.name, as_ref=input_arg.is_ref)
            except TypeError:
              # When converting a python object such as a list of Dimensions, we
              # need a dtype to be specified, thus tensor conversion may throw
              # an exception which we will ignore and try again below.
              pass

            # If we did not match an allowed dtype, try again with the default
            # dtype. This could be because we have an empty tensor and thus we
            # picked the wrong type.
            if inferred is not None and inferred.dtype in allowed_list:
              values = inferred
            else:
              values = ops.convert_to_tensor(
                  values,
                  name=input_arg.name,
                  as_ref=input_arg.is_ref,
                  preferred_dtype=default_dtype)
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                dtype=dtype,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        except TypeError as err:
          if dtype is None:
            # No expected type to report; propagate the original error.
            raise
          else:
            raise TypeError(
                f"Expected {dtypes.as_dtype(dtype).name} passed to parameter "
                f"'{input_arg.name}' of op '{op_type_name}', got "
                f"{repr(values)} of type '{type(values).__name__}' instead. "
                f"Error: {err}")
        except ValueError:
          # What type does convert_to_tensor think it has?
          try:
            observed = ops.convert_to_tensor(
                values, as_ref=input_arg.is_ref).dtype.name
          except ValueError as err:
            raise ValueError(
                f"Tried to convert '{input_name}' to a tensor and failed. "
                f"Error: {err}")
          prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                    (input_name, op_type_name, observed))
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError(f"{prefix} expected type of "
                            f"{dtypes.as_dtype(input_arg.type).name}.")
          else:
            # Update the maps with the default, if needed.
            k = input_arg.type_attr
            if k in default_type_attr_map:
              if k not in attrs:
                attrs[k] = default_type_attr_map[k]
                if k not in inferred_from:
                  inferred_from[k] = "Default in OpDef"

            raise TypeError(
                f"{prefix} type "
                f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
                f"argument '{inferred_from[input_arg.type_attr]}'.")

        types = [values.dtype]
        inputs.append(values)
      base_types = [x.base_dtype for x in types]

      if input_arg.number_attr:
        # <number-attr> * <type> or <number-attr> * <type-attr>
        if input_arg.number_attr in attrs:
          if len(values) != attrs[input_arg.number_attr]:
            raise ValueError(
                f"List argument '{input_name}' to '{op_type_name}' Op with "
                f"length {len(values)} must match length "
                f"{attrs[input_arg.number_attr]} of argument "
                f"'{inferred_from[input_arg.number_attr]}'.")
        else:
          attrs[input_arg.number_attr] = len(values)
          inferred_from[input_arg.number_attr] = input_name
          num_attr = _Attr(op_def, input_arg.number_attr)
          if num_attr.has_minimum and len(values) < num_attr.minimum:
            raise ValueError(
                f"List argument '{input_name}' to '{op_type_name}' Op with "
                f"length {len(values)} shorter than minimum length "
                f"{num_attr.minimum}.")
        # All tensors must have the same base type.
        if any(bt != base_types[0] for bt in base_types):
          raise TypeError(
              f"All tensors passed to '{input_name}' of '{op_type_name}' Op "
              f"must have the same type. Got {base_types} instead.")
        if input_arg.type != types_pb2.DT_INVALID:
          # <number-attr> * <type> case
          if base_types and base_types[0] != input_arg.type:
            assert False, "Unreachable"
        elif input_arg.type_attr in attrs:
          # <number-attr> * <type-attr> case, where <type-attr> already
          # has an inferred value.
          if base_types and base_types[0] != attrs[input_arg.type_attr]:
            assert False, "Unreachable"
        else:
          # <number-attr> * <type-attr> case, where we are now setting
          # the <type-attr> based on this input
          if not base_types:
            # If it's in default_type_attr_map, then wait to set it
            # (in "process remaining attrs", below).
            if input_arg.type_attr not in default_type_attr_map:
              raise TypeError(
                  "Don't know how to infer type variable from empty input "
                  f"list passed to input '{input_name}' of '{op_type_name}' "
                  "Op.")
          else:
            attrs[input_arg.type_attr] = base_types[0]
            inferred_from[input_arg.type_attr] = input_name
            type_attr = _Attr(op_def, input_arg.type_attr)
            _SatisfiesTypeConstraint(base_types[0], type_attr,
                                     param_name=input_name)
      elif input_arg.type_attr:
        # <type-attr>
        attr_value = base_types[0]
        if input_arg.type_attr in attrs:
          if attrs[input_arg.type_attr] != attr_value:
            raise TypeError(
                f"Input '{input_name}' of '{op_type_name}' Op has type "
                f"{dtypes.as_dtype(attr_value).name} that does not match type "
                f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
                f"argument '{inferred_from[input_arg.type_attr]}'.")
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_attr),
                                     param_name=input_name)
          attrs[input_arg.type_attr] = attr_value
          inferred_from[input_arg.type_attr] = input_name
      elif input_arg.type_list_attr:
        # <type-list-attr>
        attr_value = base_types
        if input_arg.type_list_attr in attrs:
          if attrs[input_arg.type_list_attr] != attr_value:
            actual_types = ", ".join(
                dtypes.as_dtype(x).name for x in attr_value)
            expected_types = ", ".join(
                dtypes.as_dtype(x).name
                for x in attrs[input_arg.type_list_attr])
            raise TypeError(
                f"Input '{input_name}' of '{op_type_name}' Op has type list of "
                f"{actual_types} that does not match type list {expected_types}"
                f" of argument '{inferred_from[input_arg.type_list_attr]}'.")
        else:
          for base_type in base_types:
            _SatisfiesTypeConstraint(base_type,
                                     _Attr(op_def, input_arg.type_list_attr),
                                     param_name=input_name)
          attrs[input_arg.type_list_attr] = attr_value
          inferred_from[input_arg.type_list_attr] = input_name
      else:
        # single Tensor with specified type
        if base_types[0] != input_arg.type:
          assert False, "Unreachable"

      # Ref inputs keep their ref dtypes; everything else records base types.
      if input_arg.is_ref:
        if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
          raise TypeError(
              f"'{op_type_name}' Op requires that input '{input_name}' be a "
              "mutable tensor (e.g.: a tf.Variable)")
        input_types.extend(types)
      else:
        input_types.extend(base_types)

    # Process remaining attrs
    for attr in op_def.attr:
      # Skip attrs that have already had their values inferred
      if attr.name in attrs:
        if attr.name in keywords:
          raise TypeError(
              f"Should not specify value for inferred attr '{attr.name}' for "
              f"{op_type_name}.")
        continue
      if attr.name in keywords:
        attrs[attr.name] = keywords.pop(attr.name)
      elif attr.name + "_" in keywords:
        # Attrs whose names match Python keywords have an extra '_'
        # appended, so we must check for that as well.
        attrs[attr.name] = keywords.pop(attr.name + "_")
      elif attr.name in default_type_attr_map:
        attrs[attr.name] = default_type_attr_map[attr.name]
        inferred_from.setdefault(attr.name, "Default in OpDef")
      else:
        raise TypeError(f"No argument found for attr {attr.name} for "
                        f"{op_type_name}")

    # Convert attr values to AttrValue protos.
    attr_protos = {}
    for attr_def in op_def.attr:
      key = attr_def.name
      value = attrs[key]

      if attr_def.HasField("default_value") and value is None:
        attr_value = attr_value_pb2.AttrValue()
        attr_value.CopyFrom(attr_def.default_value)
        attr_protos[key] = attr_value
        continue

      attr_value = value_to_attr_value(value, attr_def.type, key)
      if attr_def.type.startswith("list("):
        _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
      if attr_def.HasField("allowed_values"):
        if attr_def.type == "string":
          _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
                                             op_type_name)
        elif attr_def.type == "list(string)":
          # Distinct loop names so the outer `value` is not shadowed.
          for str_item in attr_value.list.s:
            _SatisfiesAllowedStringsConstraint(str_item, attr_def, key,
                                               op_type_name)
      if attr_def.has_minimum and attr_def.type == "int":
        _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key,
                                       op_type_name)
      if attr_def.type == "type":
        _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
      if attr_def.type == "list(type)":
        for type_item in attr_value.list.type:
          _SatisfiesTypeConstraint(type_item, attr_def, key)

      attr_protos[key] = attr_value
    del attrs  # attrs is no longer authoritative, use attr_protos instead

    # Determine output types (possibly using attrs)
    output_structure = []
    for arg in op_def.output_arg:
      if arg.number_attr:
        n = _AttrValue(attr_protos, arg.number_attr, op_type_name).i
        output_structure.append(n)
      elif arg.type_attr:
        # Result unused; the lookup is kept as in the original
        # (presumably it validates that the attr exists -- TODO confirm).
        _AttrValue(attr_protos, arg.type_attr, op_type_name)
        output_structure.append(None)
      elif arg.type_list_attr:
        type_list = _AttrValue(attr_protos, arg.type_list_attr, op_type_name)
        output_structure.append(len(type_list.list.type))
      else:
        output_structure.append(None)

    # Anything still left in `keywords` matched neither an input nor an attr.
    if keywords:
      all_keywords = ", ".join(sorted(keywords.keys()))
      raise TypeError(f"{op_type_name} got unexpected keyword arguments: "
                      f"{all_keywords}.")

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` per se can be decoupled, so that the
    # `op_callbacks` can function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs
def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Creates an op of type `op_type_name` and returns it with its outputs.

  Args:
    op_type_name: Op type name. Also used as the name-scope name when `name`
      is falsy.
    name: Optional name for the name scope the op is created under.
    **keywords: Keyword arguments forwarded to the op-def lookup and to the
      input/attr extraction helpers.

  Returns:
    A 4-tuple `(output_structure, is_stateful, op, outputs)`, where
    `is_stateful` is `op_def.is_stateful` and `outputs` is `op.outputs`
    unless op callbacks are active and hand back a non-None replacement.
  """
  op_def, g, producer = _GetOpDef(op_type_name, keywords)
  name = name or op_type_name

  use_fast_path = (
      _CanExtractAttrsFastPath(op_def, keywords) and
      flags.config().graph_building_optimization.value())

  if use_fast_path:
    # Fast path: a single pybind call yields the attr protos, inputs, input
    # types and output structure.
    attr_protos, inputs, input_types, output_structure = (
        op_def_library_pybind.process_inputs(op_type_name, producer, keywords))
  else:
    # Slow path: containers handed to the Python helpers below, which are
    # expected to populate them in place.
    attrs = {}
    attr_protos = {}
    default_type_attr_map = {}
    allowed_list_attr_map = {}
    inputs = []
    input_types = []
    output_structure = []
    _CheckOpDeprecation(op_type_name, op_def, producer)
    _ExtractDefaultTypesAndAllowedTypes(op_def, default_type_attr_map,
                                        allowed_list_attr_map)

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  with g.as_default(), ops.name_scope(name) as scope:
    if not use_fast_path:
      _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                             keywords, default_type_attr_map, attrs, inputs,
                             input_types)
      _ExtractRemainingAttrs(op_type_name, op_def, keywords,
                             default_type_attr_map, attrs)
      _ExtractAttrProto(op_type_name, op_def, attrs, attr_protos)
      del attrs  # attrs is no longer authoritative, use attr_protos instead
      _ExtractOutputStructure(op_type_name, op_def, attr_protos,
                              output_structure)
      _CheckAllInputsUsed(op_type_name, keywords)

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    ref_inputs = [
        input_value
        for input_arg, input_value in zip(op_def.input_arg, inputs)
        if input_arg.is_ref
    ]
    with _MaybeColocateWith(ref_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(
          op_type_name,
          inputs,
          dtypes=None,
          name=scope,
          input_types=input_types,
          attrs=attr_protos,
          op_def=op_def)

    # `outputs` is returned as a separate value so that the output tensors
    # and the `op` itself are decoupled, which lets `op_callbacks` substitute
    # the outputs. See framework/op_callbacks.py for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      replacement_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op,
          tuple(op.inputs),
          attr_protos,
          tuple(outputs),
          op_name=op.name,
          graph=g)
      if replacement_outputs is not None:
        outputs = replacement_outputs

    return output_structure, op_def.is_stateful, op, outputs