def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
  """Replaces missing entries of `grad_ys` with all-ones tensors.

  Each None entry becomes a tensor filled with 1 that has the shape and dtype
  of the matching element of `ys`. Supplied entries are type-checked against
  their `ys` counterpart.

  Args:
    grad_ys: List of gradients, one per element of `ys`; may contain None.
    ys: List of tensors.
    colocate_gradients_with_ops: If True, try colocating gradients with the
      corresponding op.

  Returns:
    A list of gradients to use, without None.

  Raises:
    ValueError: If the two lists differ in length, or a supplied gradient's
      dtype differs from the dtype of its `y`.
  """
  if len(grad_ys) != len(ys):
    raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
  grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
  for idx, target in enumerate(ys):
    supplied = grad_ys[idx]
    if supplied is not None:
      if supplied.dtype == target.dtype:
        continue
      raise ValueError("Y and ys_grad must be of the same type, "
                       "not y: %s, ys_grad: %s " %
                       (types.as_dtype(target.dtype).name,
                        types.as_dtype(supplied.dtype).name))
    # No gradient supplied: seed with ones shaped like the target.
    with ops.device(_GetGradsDevice(target.op, colocate_gradients_with_ops)):
      target_shape = array_ops.shape(target)
      one = constant_op.constant(1, dtype=target.dtype)
      grad_ys[idx] = array_ops.fill(target_shape, one)
  return grad_ys
def _ComputeGradient(x, x_shape, dx, y, y_shape, dy, x_init_value=None,
                     delta=1e-3):
  """Computes the theoretical and numerical jacobian.

  Args:
    x: The input tensor; its base dtype must be float32 or float64.
    x_shape: The dimensions of `x`.
    dx: Passed through to `_ComputeTheoricalJacobian` (presumably the gradient
      of `y` with respect to `x` -- confirm against that helper).
    y: The output tensor; its base dtype must be float32 or float64.
    y_shape: The dimensions of `y`.
    dy: Passed through to `_ComputeTheoricalJacobian` (presumably the
      incoming gradient placeholder for `y` -- confirm).
    x_init_value: Optional numpy array used as the value of `x`. Its shape
      must equal `x_shape`. When None, uniform random values are used.
    delta: Perturbation passed to `_ComputeNumericJacobian`.

  Returns:
    A pair `(jacob_t, jacob_n)` of the theoretical and numeric jacobians.
  """
  t = types.as_dtype(x.dtype)
  allowed_types = [types.float32, types.float64]
  assert t.base_dtype in allowed_types, "Don't support type %s for x" % t.name
  t2 = types.as_dtype(y.dtype)
  assert t2.base_dtype in allowed_types, "Don't support type %s for y" % t2.name

  if x_init_value is not None:
    i_shape = list(x_init_value.shape)
    assert list(x_shape) == i_shape, (
        "x_shape = %s, init_data shape = %s" % (x_shape, i_shape))
    x_data = x_init_value
  else:
    # The asserts above accept any dtype whose *base* type is allowed, so
    # compare base types here too; otherwise a non-base float32 dtype would
    # fall through to float64.
    if t.base_dtype == types.float32:
      dtype = np.float32
    else:
      dtype = np.float64
    # np.asfarray was removed in NumPy 2.0; np.asarray with an explicit float
    # dtype is the documented replacement (and also handles 0-d results).
    x_data = np.asarray(np.random.random_sample(x_shape), dtype=dtype)

  jacob_t = _ComputeTheoricalJacobian(x, x_shape, x_data, dy, y_shape, dx)
  jacob_n = _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta)
  return jacob_t, jacob_n
def _SatisfiesTypeConstraint(dtype, attr_def):
  """Checks `dtype` against the allowed-type list of `attr_def`, if it has one.

  Args:
    dtype: A DataType enum value, compared directly against the entries of
      `attr_def.allowed_values.list.type`.
    attr_def: An attr definition. Only its optional `allowed_values` field and
      its `name` are read.

  Raises:
    TypeError: If `attr_def` declares allowed values and `dtype` is not one of
      them.
  """
  if not attr_def.HasField("allowed_values"):
    return
  permitted = attr_def.allowed_values.list.type
  if dtype in permitted:
    return
  offending_name = types_lib.as_dtype(dtype).name
  permitted_names = ", ".join(types_lib.as_dtype(p).name for p in permitted)
  raise TypeError(
      "DataType %s for attr '%s' not in list of allowed values: %s" %
      (offending_name, attr_def.name, permitted_names))
def _SatisfiesTypeConstraint(dtype, attr_def):
  """Raises TypeError when `dtype` violates the allowed-values list of `attr_def`.

  Attrs without an `allowed_values` field accept every type.

  Args:
    dtype: A DataType enum value.
    attr_def: An attr definition whose optional `allowed_values` field
      restricts the accepted types.

  Raises:
    TypeError: If `dtype` is not in `attr_def.allowed_values.list.type`.
  """
  if attr_def.HasField("allowed_values"):
    allowed = attr_def.allowed_values.list.type
    if dtype not in allowed:
      bad_name = types_lib.as_dtype(dtype).name
      allowed_names = [types_lib.as_dtype(t).name for t in allowed]
      message = ("DataType %s for attr '%s' not in list of allowed values: %s"
                 % (bad_name, attr_def.name, ", ".join(allowed_names)))
      raise TypeError(message)
def testIsFloating(self):
  # Expected `is_floating` for each type name: only the float types are True.
  expectations = [
      ("int8", False),
      ("int16", False),
      ("int32", False),
      ("int64", False),
      ("uint8", False),
      ("complex64", False),
      ("float32", True),
      ("float64", True),
      ("string", False),
      ("bool", False),
  ]
  for type_name, expected in expectations:
    self.assertEqual(types.as_dtype(type_name).is_floating, expected)
def testIsInteger(self):
  # Expected `is_integer` for each type name: the signed and unsigned int types
  # are True; complex, the float aliases, string and bool are False.
  expectations = [
      ("int8", True),
      ("int16", True),
      ("int32", True),
      ("int64", True),
      ("uint8", True),
      ("complex64", False),
      ("float", False),
      ("double", False),
      ("string", False),
      ("bool", False),
  ]
  for type_name, expected in expectations:
    self.assertEqual(types.as_dtype(type_name).is_integer, expected)
def testAllTypesConvertibleToNumpyDtype(self):
  # Every valid DataType must map to a numpy dtype that numpy can allocate,
  # and (except bfloat16) that numpy dtype must map back to the same base type.
  valid_enums = [e for e in types_pb2.DataType.values()
                 if e != types_pb2.DT_INVALID]
  for enum_value in valid_enums:
    tf_dtype = types.as_dtype(enum_value)
    np_dtype = tf_dtype.as_numpy_dtype
    np.empty((1, 1, 1, 1), dtype=np_dtype)
    if tf_dtype.base_dtype == types.bfloat16:
      # NOTE(mdevin): Intentionally no way to feed a DT_BFLOAT16.
      continue
    self.assertEqual(tf_dtype.base_dtype, types.as_dtype(np_dtype))
def testAllTypesConvertibleToNumpyDtype(self):
  # Round-trip check: DataType -> numpy dtype (allocatable) -> DataType.
  for enum_value in types_pb2.DataType.values():
    if enum_value == types_pb2.DT_INVALID:
      continue
    tf_dtype = types.as_dtype(enum_value)
    np_dtype = tf_dtype.as_numpy_dtype
    np.empty((1, 1, 1, 1), dtype=np_dtype)
    # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
    is_feedable = tf_dtype.base_dtype != types.bfloat16
    if is_feedable:
      expected = tf_dtype.base_dtype
      self.assertEqual(expected, types.as_dtype(np_dtype))
def ones(shape, dtype=types.float32, name=None):
  """Creates a tensor with all elements set to 1.

  This operation returns a tensor of type `dtype` with shape `shape` and all
  elements set to 1.

  For example:

  ```python
  tf.ones([2, 3], tf.int32) ==> [[1, 1, 1], [1, 1, 1]]
  ```

  Args:
    shape: Either a list of integers, or a 1-D `Tensor` of type `int32`.
    dtype: The type of an element in the resulting `Tensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` with all elements set to 1.
  """
  with ops.op_scope([shape], name, "ones") as name:
    if isinstance(shape, list):
      # A static Python list: emit the ones directly as one constant.
      output = constant(1, shape=shape, dtype=dtype, name=name)
    else:
      # Otherwise treat `shape` as a tensor and broadcast a scalar 1 into it.
      shape = ops.convert_to_tensor(shape, name="shape")
      output = fill(shape, constant(1, dtype=dtype), name=name)
  assert output.dtype.base_dtype == types.as_dtype(dtype).base_dtype
  return output
def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
                   name="restore_slice", preferred_shard=-1):
  """Restore a tensor slice from a set of files with a given pattern.

  Example usage:
    RestoreSlice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-", DT_FLOAT)

  Args:
    file_pattern: the file pattern used to match a set of checkpoint files.
    tensor_name: the name of the tensor to restore.
    shape_and_slice: the shape-and-slice spec of the slice.
    tensor_type: the type of the tensor to restore. Converted to its base
      dtype before being passed to the generated op.
    name: string.  Optional name for the op.
    preferred_shard: Int. Optional shard to open first in the checkpoint file.

  Returns:
    A tensor of type "tensor_type".
  """
  requested = types.as_dtype(tensor_type)
  return gen_io_ops._restore_slice(file_pattern,
                                   tensor_name,
                                   shape_and_slice,
                                   requested.base_dtype,
                                   preferred_shard,
                                   name=name)
def __init__(self, key_dtype, value_dtype, default_value, table_ref):
  """Construct a table object from a table reference.

  Args:
    key_dtype: The table key type.
    value_dtype: The table value type.
    default_value: The value to use if a key is missing in the table. It is
      converted to a tensor of `value_dtype`.
    table_ref: The table reference, i.e. the output of the lookup table ops.
  """
  self._key_dtype = types.as_dtype(key_dtype)
  self._value_dtype = types.as_dtype(value_dtype)
  self._shapes = [tensor_shape.TensorShape([1])]
  self._table_ref = table_ref
  # The table's short name is the last "/"-separated component of its op name.
  self._name = table_ref.op.name.rpartition("/")[2]
  self._default_value = ops.convert_to_tensor(default_value,
                                              dtype=self._value_dtype)
  # The merged shape is discarded. NOTE(review): this call is presumably kept
  # only as a check that default_value is scalar-compatible
  # (TensorShape.merge_with raises on incompatible shapes) -- confirm.
  self._default_value.get_shape().merge_with(tensor_shape.scalar())
def _MakeType(v, attr_def):
  """Converts `v` to a DataType enum value that `attr_def` accepts.

  Args:
    v: Anything `types_lib.as_dtype` accepts.
    attr_def: The attr definition whose type constraint is enforced.

  Returns:
    The DataType enum value for `v`.

  Raises:
    TypeError: If `v` is not convertible to a DType, or the resulting type is
      not allowed by `attr_def`.
  """
  try:
    converted = types_lib.as_dtype(v)
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (attr_def.name, repr(v)))
  enum_value = converted.as_datatype_enum
  _SatisfiesTypeConstraint(enum_value, attr_def)
  return enum_value
def testDTypesHaveUniqueNames(self):
  # No two valid DataTypes may share a DType name.
  all_dtypes = [types.as_dtype(e) for e in types_pb2.DataType.values()
                if e != types_pb2.DT_INVALID]
  distinct_names = {d.name for d in all_dtypes}
  self.assertEqual(len(all_dtypes), len(distinct_names))
def _VerifyGeneratedGradients(grads, op):
  """Verify that gradients are valid in number and type.

  Args:
    grads: List of generated gradients; None entries are skipped.
    op: Operation for which the gradients were generated.

  Raises:
    ValueError: if there is not exactly one gradient per input of `op`, or a
      non-None gradient has a dtype incompatible with its input.
  """
  if len(grads) != len(op.inputs):
    raise ValueError("Num gradients %d generated for op %s do not match num "
                     "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
  for grad, inp in zip(grads, op.inputs):
    if grad is None or grad.dtype.is_compatible_with(inp.dtype):
      continue
    raise ValueError(
        "Gradient type %s generated for op %s does "
        "not match input type %s" %
        (types.as_dtype(grad.dtype).name, op.node_def,
         types.as_dtype(inp.dtype).name))
def _ComputeGradient(x, x_shape, dx, y, y_shape, dy, x_init_value=None,
                     delta=1e-3):
  """Computes the theoretical and numerical jacobian.

  Args:
    x: The input tensor; its base dtype must be float32 or float64.
    x_shape: The dimensions of `x`.
    dx: Passed through to `_ComputeTheoricalJacobian` (presumably the gradient
      of `y` with respect to `x` -- confirm against that helper).
    y: The output tensor; its base dtype must be float32 or float64.
    y_shape: The dimensions of `y`.
    dy: Passed through to `_ComputeTheoricalJacobian` (presumably the
      incoming gradient placeholder for `y` -- confirm).
    x_init_value: Optional numpy array used as the value of `x`. Its shape
      must equal `x_shape`. When None, uniform random values are used.
    delta: Perturbation passed to `_ComputeNumericJacobian`.

  Returns:
    A pair `(jacob_t, jacob_n)` of the theoretical and numeric jacobians.
  """
  t = types.as_dtype(x.dtype)
  allowed_types = [types.float32, types.float64]
  assert t.base_dtype in allowed_types, "Don't support type %s for x" % t.name
  t2 = types.as_dtype(y.dtype)
  assert t2.base_dtype in allowed_types, "Don't support type %s for y" % t2.name

  if x_init_value is not None:
    i_shape = list(x_init_value.shape)
    assert list(x_shape) == i_shape, (
        "x_shape = %s, init_data shape = %s" % (x_shape, i_shape))
    x_data = x_init_value
  else:
    # The asserts above accept any dtype whose *base* type is allowed, so
    # compare base types here as well; otherwise a non-base float32 dtype
    # would silently get float64 data.
    if t.base_dtype == types.float32:
      dtype = np.float32
    else:
      dtype = np.float64
    # np.asfarray was removed in NumPy 2.0; np.asarray with an explicit float
    # dtype is the documented replacement (and also handles 0-d results).
    x_data = np.asarray(np.random.random_sample(x_shape), dtype=dtype)

  jacob_t = _ComputeTheoricalJacobian(x, x_shape, x_data, dy, y_shape, dx)
  jacob_n = _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta)
  return jacob_t, jacob_n
def testMinMax(self):
  # make sure min/max evaluates for all data types that have min/max
  for datatype_enum in types_pb2.DataType.values():
    if datatype_enum == types_pb2.DT_INVALID:
      continue
    dtype = types.as_dtype(datatype_enum)
    numpy_dtype = dtype.as_numpy_dtype

    # ignore types for which there are no minimum/maximum (or we cannot
    # compute it, such as for the q* types)
    if (dtype.is_quantized or
        dtype.base_dtype == types.bool or
        dtype.base_dtype == types.string or
        dtype.base_dtype == types.complex64):
      continue

    print("%s: %s - %s" % (dtype, dtype.min, dtype.max))

    # check some values that are known
    # (assertEqual, not the deprecated assertEquals alias removed in 3.12)
    if numpy_dtype == np.bool_:
      self.assertEqual(dtype.min, 0)
      self.assertEqual(dtype.max, 1)
    if numpy_dtype == np.int8:
      self.assertEqual(dtype.min, -128)
      self.assertEqual(dtype.max, 127)
    if numpy_dtype == np.int16:
      self.assertEqual(dtype.min, -32768)
      self.assertEqual(dtype.max, 32767)
    if numpy_dtype == np.int32:
      self.assertEqual(dtype.min, -2147483648)
      self.assertEqual(dtype.max, 2147483647)
    if numpy_dtype == np.int64:
      self.assertEqual(dtype.min, -9223372036854775808)
      self.assertEqual(dtype.max, 9223372036854775807)
    if numpy_dtype == np.uint8:
      self.assertEqual(dtype.min, 0)
      self.assertEqual(dtype.max, 255)
    if numpy_dtype == np.uint16:
      # NOTE(review): 4294967295 is the uint32 maximum, not uint16's (65535).
      # Left unchanged because the dtype mapped to np.uint16 here may not be
      # uint16 itself (e.g. bfloat16), whose reported max may differ --
      # confirm against DType.max.
      self.assertEqual(dtype.min, 0)
      self.assertEqual(dtype.max, 4294967295)
    if numpy_dtype == np.uint32:
      # The uint32 maximum is 2**32 - 1 (the old value was uint64's maximum).
      self.assertEqual(dtype.min, 0)
      self.assertEqual(dtype.max, 4294967295)
    if numpy_dtype in (np.float16, np.float32, np.float64):
      self.assertEqual(dtype.min, np.finfo(numpy_dtype).min)
      self.assertEqual(dtype.max, np.finfo(numpy_dtype).max)
def testMinMax(self):
  # make sure min/max evaluates for all data types that have min/max
  for datatype_enum in types_pb2.DataType.values():
    if datatype_enum == types_pb2.DT_INVALID:
      continue
    dtype = types.as_dtype(datatype_enum)
    numpy_dtype = dtype.as_numpy_dtype

    # ignore types for which there are no minimum/maximum (or we cannot
    # compute it, such as for the q* types)
    if (dtype.is_quantized or
        dtype.base_dtype == types.bool or
        dtype.base_dtype == types.string or
        dtype.base_dtype == types.complex64):
      continue

    # Parenthesized single-argument print behaves the same on Python 2 and is
    # valid syntax on Python 3 (the bare print statement is not).
    print("%s: %s - %s" % (dtype, dtype.min, dtype.max))

    # check some values that are known
    # (assertEqual, not the deprecated assertEquals alias removed in 3.12)
    if numpy_dtype == np.bool_:
      self.assertEqual(dtype.min, 0)
      self.assertEqual(dtype.max, 1)
    if numpy_dtype == np.int8:
      self.assertEqual(dtype.min, -128)
      self.assertEqual(dtype.max, 127)
    if numpy_dtype == np.int16:
      self.assertEqual(dtype.min, -32768)
      self.assertEqual(dtype.max, 32767)
    if numpy_dtype == np.int32:
      self.assertEqual(dtype.min, -2147483648)
      self.assertEqual(dtype.max, 2147483647)
    if numpy_dtype == np.int64:
      self.assertEqual(dtype.min, -9223372036854775808)
      self.assertEqual(dtype.max, 9223372036854775807)
    if numpy_dtype == np.uint8:
      self.assertEqual(dtype.min, 0)
      self.assertEqual(dtype.max, 255)
    if numpy_dtype == np.uint16:
      # NOTE(review): 4294967295 is the uint32 maximum, not uint16's (65535).
      # Left unchanged because the dtype mapped to np.uint16 here may not be
      # uint16 itself (e.g. bfloat16), whose reported max may differ --
      # confirm against DType.max.
      self.assertEqual(dtype.min, 0)
      self.assertEqual(dtype.max, 4294967295)
    if numpy_dtype == np.uint32:
      # The uint32 maximum is 2**32 - 1 (the old value was uint64's maximum).
      self.assertEqual(dtype.min, 0)
      self.assertEqual(dtype.max, 4294967295)
    if numpy_dtype in (np.float16, np.float32, np.float64):
      self.assertEqual(dtype.min, np.finfo(numpy_dtype).min)
      self.assertEqual(dtype.max, np.finfo(numpy_dtype).max)
def __init__(self, op, value_index, dtype):
  """Creates a new `Tensor`.

  Args:
    op: An `Operation`. `Operation` that computes this tensor.
    value_index: An `int`. Index of the operation's endpoint that produces
      this tensor.
    dtype: A `types.DType`. Type of data stored in this tensor; it is
      normalized through `types.as_dtype`.

  Raises:
    TypeError: If the op is not an `Operation`.
  """
  if not isinstance(op, Operation):
    raise TypeError("op needs to be an Operation: %s" % op)
  self._op, self._value_index = op, value_index
  self._dtype = types.as_dtype(dtype)
  # The shape starts out unknown.
  self._shape = tensor_shape.unknown_shape()
  # Operations that take this tensor as an input, kept so the computation
  # graph can be walked forward from here.
  self._consumers = []
def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
                   name="restore_slice", preferred_shard=-1):
  """Restore a tensor slice from a set of files with a given pattern.

  Example usage:
    RestoreSlice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-", DT_FLOAT)

  Args:
    file_pattern: the file pattern used to match a set of checkpoint files.
    tensor_name: the name of the tensor to restore.
    shape_and_slice: the shape-and-slice spec of the slice.
    tensor_type: the type of the tensor to restore; only its base dtype is
      forwarded to the generated op.
    name: string.  Optional name for the op.
    preferred_shard: Int. Optional shard to open first in the checkpoint file.

  Returns:
    A tensor of type "tensor_type".
  """
  op_args = (file_pattern, tensor_name, shape_and_slice,
             types.as_dtype(tensor_type).base_dtype, preferred_shard)
  return gen_io_ops._restore_slice(*op_args, name=name)
def MakeNdarray(tensor):
  """Create a numpy ndarray from a tensor.

  Create a numpy ndarray with the same shape and data as the tensor.

  Args:
    tensor: A TensorProto.

  Returns:
    A numpy array with the tensor contents.

  Raises:
    TypeError: if tensor has unsupported type.

  """
  shape = [d.size for d in tensor.tensor_shape.dim]
  # np.prod([]) (a scalar tensor) is the float 1.0, and np.repeat refuses a
  # float repeat count under NumPy's safe casting rules, so keep an int.
  num_elements = int(np.prod(shape, dtype=np.int64))
  tensor_dtype = types.as_dtype(tensor.dtype)
  dtype = tensor_dtype.as_numpy_dtype

  def _FromRepeatedField(values):
    # A single stored value is repeated to fill every element; otherwise the
    # stored values are read in order.
    if len(values) == 1:
      return np.repeat(np.array(values[0], dtype=dtype),
                       num_elements).reshape(shape)
    return np.fromiter(values, dtype=dtype).reshape(shape)

  if tensor.tensor_content:
    # np.fromstring's binary mode is deprecated; np.frombuffer replaces it
    # but returns a read-only view of the bytes, so copy to keep the result
    # writable as before.
    return np.frombuffer(tensor.tensor_content,
                         dtype=dtype).copy().reshape(shape)
  elif tensor_dtype == types.float32:
    return _FromRepeatedField(tensor.float_val)
  elif tensor_dtype == types.float64:
    return _FromRepeatedField(tensor.double_val)
  elif tensor_dtype in [types.int32, types.uint8, types.int16, types.int8,
                        types.qint32, types.quint8, types.qint8,
                        types.bfloat16]:
    return _FromRepeatedField(tensor.int_val)
  elif tensor_dtype == types.int64:
    return _FromRepeatedField(tensor.int64_val)
  elif tensor_dtype == types.string:
    if len(tensor.string_val) == 1:
      return np.repeat(np.array(str(tensor.string_val[0]), dtype=dtype),
                       num_elements).reshape(shape)
    else:
      return np.array([str(x) for x in tensor.string_val],
                      dtype=dtype).reshape(shape)
  elif tensor_dtype == types.complex64:
    # scomplex_val holds (real, imag) pairs flattened into one sequence.
    it = iter(tensor.scomplex_val)
    if len(tensor.scomplex_val) == 2:
      return np.repeat(np.array(complex(tensor.scomplex_val[0],
                                        tensor.scomplex_val[1]), dtype=dtype),
                       num_elements).reshape(shape)
    else:
      return np.array([complex(x[0], x[1]) for x in zip(it, it)],
                      dtype=dtype).reshape(shape)
  elif tensor_dtype == types.bool:
    return _FromRepeatedField(tensor.bool_val)
  else:
    raise TypeError("Unsupported tensor type: %s" % tensor.dtype)
def make_tensor_proto(values, dtype=None, shape=None):
  """Create a TensorProto.

  Args:
    values:         Values to put in the TensorProto.
    dtype:          Optional tensor_pb2 DataType value.
    shape:          List of integers representing the dimensions of tensor.

  Returns:
    A TensorProto. Depending on the type, it may contain data in the
    "tensor_content" attribute, which is not directly useful to Python programs.
    To access the values you should convert the proto back to a numpy ndarray
    with tensor_util.MakeNdarray(proto).

  Raises:
    TypeError:  if unsupported types are provided.
    ValueError: if arguments have inappropriate values.

  make_tensor_proto accepts "values" of a python scalar, a python list, a
  numpy ndarray, or a numpy scalar.

  If "values" is a python scalar or a python list, make_tensor_proto
  first converts it to a numpy ndarray. If dtype is None, the
  conversion tries its best to infer the right numpy data
  type. Otherwise, the resulting numpy array has a data type
  compatible with the given dtype.

  In either case above, the numpy ndarray (either the caller provided
  or the auto converted) must have a type compatible with dtype.

  make_tensor_proto then converts the numpy array to a tensor proto.

  If "shape" is None, the resulting tensor proto represents the numpy
  array precisely.

  Otherwise, "shape" specifies the tensor's shape and the numpy array
  can not have more elements than what "shape" specifies.

  """
  if dtype:
    dtype = types.as_dtype(dtype)

  # We first convert value to a numpy array or scalar.
  if isinstance(values, (np.ndarray, np.generic)):
    if dtype:
      nparray = values.astype(dtype.as_numpy_dtype)
    else:
      nparray = values
  else:
    if values is None:
      raise ValueError("None values not supported.")
    # if dtype is provided, forces numpy array to be the type
    # provided if possible.
    np_dt = dtype.as_numpy_dtype if dtype else None
    if np.prod(shape) == 0:
      nparray = np.empty(shape, dtype=np_dt)
    else:
      _AssertCompatible(values, dtype)
      nparray = np.array(values, dtype=np_dt)
      if list(nparray.shape) != _GetDenseDimensions(values):
        raise ValueError("Argument must be a dense tensor: %s" % values)
    # python/numpy default float type is float64. We prefer float32 instead.
    if (nparray.dtype == np.float64) and dtype is None:
      nparray = nparray.astype(np.float32)
    # python/numpy default int type is int64. We prefer int32 instead.
    elif (nparray.dtype == np.int64) and dtype is None:
      nparray = nparray.astype(np.int32)

  # if dtype is provided, it must be compatible with what numpy
  # conversion says.
  numpy_dtype = types.as_dtype(nparray.dtype)
  if numpy_dtype is None:
    raise TypeError("Unrecognized data type: %s" % nparray.dtype)

  # If dtype was specified and is a quantized type, we convert
  # numpy_dtype back into the quantized version.
  if dtype in [types.qint8, types.quint8, types.qint32]:
    numpy_dtype = dtype

  if dtype is not None and not dtype.base_dtype == numpy_dtype.base_dtype:
    raise TypeError("Incompatible types: %s vs. %s" % (dtype, nparray.dtype))

  # If shape is not given, get the shape from the numpy array.
  if shape is None:
    shape = nparray.shape
    is_same_size = True
    shape_size = nparray.size
  else:
    shape = [int(dim) for dim in shape]
    shape_size = np.prod(shape)
    is_same_size = shape_size == nparray.size

    if nparray.size > shape_size:
      raise ValueError(
          "Too many elements provided. Needed at most %d, but received %d" %
          (shape_size, nparray.size))

  tensor_proto = tensor_pb2.TensorProto(
      dtype=numpy_dtype.as_datatype_enum,
      tensor_shape=MakeTensorShapeProto(shape))

  if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
    # ndarray.tostring() is deprecated (removed in recent NumPy); tobytes()
    # returns the same C-order raw bytes.
    tensor_proto.tensor_content = nparray.tobytes()
    return tensor_proto

  # If we were not given values as a numpy array, compute the proto_values
  # from the given values directly, to avoid numpy trimming nulls from the
  # strings. Since values could be a list of strings, or a multi-dimensional
  # list of lists that might or might not correspond to the given shape,
  # we flatten it conservatively.
  if numpy_dtype == types.string and not isinstance(values, np.ndarray):
    proto_values = _FlattenToStrings(values)
    tensor_proto.string_val.extend([str(x) for x in proto_values])
    return tensor_proto

  # TensorFlow expects C order (a.k.a., eigen row major).
  proto_values = nparray.ravel()

  append_fn = GetNumpyAppendFn(proto_values.dtype)
  if append_fn is None:
    raise TypeError("Element type not supported in TensorProto: %s" %
                    numpy_dtype.name)
  append_fn(tensor_proto, proto_values)

  return tensor_proto
def testAllTypesConvertibleToDType(self):
  # Every valid DataType enum must round-trip through as_dtype unchanged.
  valid_enums = (e for e in types_pb2.DataType.values()
                 if e != types_pb2.DT_INVALID)
  for enum_value in valid_enums:
    round_tripped = types.as_dtype(enum_value).as_datatype_enum
    self.assertEqual(enum_value, round_tripped)
def convert_to_tensor(value, dtype=None, name=None):
  """Converts the given `value` to a `Tensor`.

  This function converts Python objects of various types to `Tensor`
  objects. It accepts `Tensor` objects, numpy arrays, Python lists,
  and Python scalars. For example:

  ```python
  import numpy as np
  array = np.random.rand(32, 100, 100)

  def my_func(arg):
    arg = tf.convert_to_tensor(arg, dtype=tf.float32)
    return tf.matmul(arg, arg) + arg

  # The following calls are equivalent.
  value_1 = my_func(tf.constant([[1.0, 2.0], [3.0, 4.0]]))
  value_2 = my_func([[1.0, 2.0], [3.0, 4.0]])
  value_3 = my_func(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
  ```

  This function can be useful when composing a new operation in Python
  (such as `my_func` in the example above). All standard Python op
  constructors apply this function to each of their Tensor-valued
  inputs, which allows those ops to accept numpy arrays, Python lists,
  and scalars in addition to `Tensor` objects.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value`.
    RuntimeError: If a registered conversion function returns an invalid value.

  """
  error_prefix = "" if name is None else "%s: " % name
  if dtype is not None:
    dtype = types.as_dtype(dtype)
  # Registry entries are tried in ascending priority-key order; the first
  # registered base type that `value` is an instance of wins.
  for _, funcs_at_priority in sorted(_tensor_conversion_func_registry.items()):
    for base_type, conversion_func in funcs_at_priority:
      if isinstance(value, base_type):
        ret = conversion_func(value, dtype=dtype, name=name)
        if not isinstance(ret, Tensor):
          raise RuntimeError(
              "%sConversion function %r for type %s returned non-Tensor: %r"
              % (error_prefix, conversion_func, base_type, ret))
        if dtype and not dtype.is_compatible_with(ret.dtype):
          raise RuntimeError(
              "%sConversion function %r for type %s returned incompatible "
              "dtype: requested = %s, actual = %s"
              % (error_prefix, conversion_func, base_type,
                 dtype.name, ret.dtype.name))
        return ret
  raise TypeError("%sCannot convert %r with type %s to Tensor: "
                  "no conversion function registered."
                  % (error_prefix, value, type(value)))
def gradients(ys,
              xs,
              grad_ys=None,
              name="gradients",
              colocate_gradients_with_ops=False,
              gate_gradients=False,
              aggregation_method=None):
  """Constructs symbolic partial derivatives of `ys` w.r.t. x in `xs`.

  `ys` and `xs` are each a `Tensor` or a list of tensors.  `grad_ys`
  is a list of `Tensor`, holding the gradients received by the
  `ys`. The list must be the same length as `ys`.

  `gradients()` adds ops to the graph to output the partial
  derivatives of `ys` with respect to `xs`.  It returns a list of
  `Tensor` of length `len(xs)` where each tensor is the `sum(dy/dx)`
  for y in `ys`.

  `grad_ys` is a list of tensors of the same length as `ys` that holds
  the initial gradients for each y in `ys`.  When `grad_ys` is None,
  we fill in a tensor of '1's of the shape of y for each y in `ys`.  A
  user can provide their own initial `grad_ys` to compute the
  derivatives using a different initial gradient for each y (e.g., if
  one wanted to weight the gradient differently for each value in
  each y).

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      defaults to 'gradients'.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operations.  This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.

  Returns:
    A list of `sum(dy/dx)` for each x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid.

  """
  ys = _AsList(ys)
  xs = _AsList(xs)
  if grad_ys is None:
    grad_ys = [None] * len(ys)
  else:
    grad_ys = _AsList(grad_ys)
  with ops.op_scope(ys + xs + grad_ys, name, "gradients"):
    ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
    xs = ops.convert_n_to_tensor_or_indexed_slices(xs, name="x")
    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)

    # The approach we take here is as follows: Create a list of all ops in the
    # subgraph between the ys and xs.  Visit these ops in reverse order of ids
    # to ensure that when we visit an op the gradients w.r.t its outputs have
    # been collected.  Then aggregate these gradients if needed, call the op's
    # gradient function, and add the generated gradients to the gradients for
    # its input.

    # Initialize the pending count for ops in the connected subgraph from ys
    # to the xs.
    to_ops = [t.op for t in ys]
    from_ops = [t.op for t in xs]
    pending_count, has_control_flow = _PendingCount(
        ops.get_default_graph(), to_ops, from_ops)

    # Iterate over the collected ops.
    #
    # grads: op => list of gradients received on each output endpoint of the
    # op.  The gradients for each endpoint are initially collected as a list.
    # When it is time to call the op's gradient function, for each endpoint we
    # aggregate the list of received gradients into a Add() Operation if there
    # is more than one.
    grads = {}

    # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

    # Initialize queue with to_ops.
    queue = collections.deque()
    # Add the ops in 'to_ops' into the queue.
    to_ops_set = set()
    for op in to_ops:
      if op._id not in to_ops_set:
        to_ops_set.add(op._id)
        queue.append(op)
    # The set of 'from_ops'.
    stop_ops = _StopOps(from_ops, pending_count)
    while queue:
      # generate gradient subgraph for op.
      op = queue.popleft()
      with ops.device(_GetGradsDevice(op, colocate_gradients_with_ops)):
        if has_control_flow:
          control_flow_ops.EnterGradWhileContext(op)
        out_grads = _AggregatedGrads(grads, op, has_control_flow,
                                     aggregation_method)
        grad_fn = None
        if any(out_grads) and op._id not in stop_ops:
          # A grad_fn must be defined, either as a function or as None
          # for ops that do not have gradients.
          try:
            grad_fn = ops.get_gradient_function(op)
          except LookupError:
            raise LookupError(
                "No gradient defined for operation '%s' (op type: %s)" %
                (op.name, op.type))
        if grad_fn and any(out_grads):
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not out_grad and
                types.as_dtype(op.outputs[i].dtype).base_dtype in (
                    types.float32, types.float64)):
              # Only floating-point outputs get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              out_grads[i] = array_ops.zeros_like(op.outputs[i])
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with ops.get_default_graph()._original_op(op):
              # pylint: enable=protected-access
              op_wrapper = op
              if has_control_flow:
                op_wrapper = control_flow_ops.MakeWrapper(op)
              in_grads = _AsList(grad_fn(op_wrapper, *out_grads))
              _VerifyGeneratedGradients(in_grads, op)
              if gate_gradients and len(in_grads) > 1:
                in_grads = control_flow_ops.tuple(in_grads)
          logging.vlog(1, "Gradient for '" + op.name + "'")
          logging.vlog(1, " in --> %s",
                       ", ".join([x.name for x in out_grads if x]))
          logging.vlog(1, " out --> %s",
                       ", ".join([x.name for x in in_grads if x]))
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagates a list of None backwards.
          in_grads = [None] * len(op.inputs)
        for t_in, in_grad in zip(op.inputs, in_grads):
          if in_grad:
            _SetGrad(grads, t_in, in_grad)
        if has_control_flow:
          control_flow_ops.ExitGradWhileContext(op)

      # update pending count for the inputs of op.
      for x in op.inputs:
        pending_count[x.op._id] -= 1
        ready = (pending_count[x.op._id] == 0)
        if has_control_flow and not ready:
          ready = (pending_count[x.op._id] > 0 and
                   control_flow_ops.IsLoopSwitch(x.op))
        if ready:
          queue.append(x.op)
      for x in op.control_inputs:
        pending_count[x._id] -= 1
        # Compare by value: `is 0` relies on CPython small-int caching and is
        # flagged as a SyntaxWarning on modern Python.
        if pending_count[x._id] == 0:
          queue.append(x)
  return [_GetGrad(grads, x) for x in xs]
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef`
      protos. Must contain an `OpDef` proto for each op type named in
      `graph_def`. If omitted, uses the `OpDef` protos registered in the
      global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` when
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, six.string_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if (return_elements is not None
      and not (isinstance(return_elements, (list, tuple))
               and all(isinstance(x, six.string_types)
                       for x in return_elements))):
    raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()

    with ops.name_scope('_inputs'):
      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      output_types = _OutputTypes(node, op_dict)
      with _MaybeDevice(node.device):
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # NOTE(mrry): We cannot use zip here because control inputs do not
      # appear in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.'
                    % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'More inputs specified (%r) than the op expects.'
                    % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            # `str(te)` rather than `te.message`: exceptions have no
            # `.message` attribute on Python 3, and accessing it here would
            # raise AttributeError and hide the real type error.
            raise ValueError(
                _InvalidNodeMessage(
                    node, 'Input tensor %r %s' % (input_name, str(te))))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (", ".join(types_lib.as_dtype(x).name for x in input_types),
                   ", ".join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      # Execute shape inference for this op.
      # NOTE(mrry): If the graph contains a cycle, the full shape information
      # may not be available for this op's inputs.
      ops.set_shapes_for_outputs(op)

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo. Keys are sorted so the message is deterministic.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(sorted(unused_input_keys)))

    if return_elements is None:
      return None
    else:
      ret = []
      # `element_name` (not `name`) so the `name` argument is not shadowed.
      for element_name in return_elements:
        if ':' in element_name:
          try:
            operation_name, output_index = _ParseTensorName(element_name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.'
                % element_name)
        else:
          try:
            ret.append(name_to_op[element_name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.'
                % element_name)
      return ret
def make_tensor_proto(values, dtype=None, shape=None):
  """Create a TensorProto.

  make_tensor_proto accepts "values" of a python scalar, a python list, a
  numpy ndarray, or a numpy scalar.

  If "values" is a python scalar or a python list, make_tensor_proto
  first converts it to a numpy ndarray. If dtype is None, the
  conversion tries its best to infer the right numpy data
  type. Otherwise, the resulting numpy array has a compatible data
  type with the given dtype. When dtype is None, python floats and ints
  are stored as float32 and int32 rather than numpy's 64-bit defaults.

  In either case above, the numpy ndarray (either the caller provided
  or the auto converted) must have a type compatible with dtype.

  make_tensor_proto then converts the numpy array to a tensor proto.

  If "shape" is None, the resulting tensor proto represents the numpy
  array precisely.

  Otherwise, "shape" specifies the tensor's shape and the numpy array
  can not have more elements than what "shape" specifies.

  Args:
    values:         Values to put in the TensorProto.
    dtype:          Optional tensor_pb2 DataType value.
    shape:          List of integers representing the dimensions of tensor.

  Returns:
    A TensorProto. Depending on the type, it may contain data in the
    "tensor_content" attribute, which is not directly useful to Python
    programs. To access the values you should convert the proto back to a
    numpy ndarray with tensor_util.MakeNdarray(proto).

  Raises:
    TypeError:  if unsupported types are provided.
    ValueError: if arguments have inappropriate values.
  """
  if dtype:
    dtype = types.as_dtype(dtype)

  # We first convert value to a numpy array or scalar.
  if isinstance(values, (np.ndarray, np.generic)):
    if dtype:
      nparray = values.astype(dtype.as_numpy_dtype)
    else:
      nparray = values
  else:
    if values is None:
      raise ValueError("None values not supported.")
    # if dtype is provided, forces numpy array to be the type
    # provided if possible.
    np_dt = dtype.as_numpy_dtype if dtype else None
    # Only an explicitly given, zero-sized shape short-circuits to an empty
    # array; guard against shape=None instead of relying on np.prod(None).
    if shape is not None and np.prod(shape) == 0:
      nparray = np.empty(shape, dtype=np_dt)
    else:
      _AssertCompatible(values, dtype)
      nparray = np.array(values, dtype=np_dt)
      if list(nparray.shape) != _GetDenseDimensions(values):
        raise ValueError("Argument must be a dense tensor: %s" % values)
    # python/numpy default float type is float64. We prefer float32 instead.
    if (nparray.dtype == np.float64) and dtype is None:
      nparray = nparray.astype(np.float32)
    # python/numpy default int type is int64. We prefer int32 instead.
    elif (nparray.dtype == np.int64) and dtype is None:
      nparray = nparray.astype(np.int32)

  # if dtype is provided, it must be compatible with what numpy
  # conversion says.
  numpy_dtype = types.as_dtype(nparray.dtype)
  if numpy_dtype is None:
    raise TypeError("Unrecognized data type: %s" % nparray.dtype)

  # If dtype was specified and is a quantized type, we convert
  # numpy_dtype back into the quantized version.
  if dtype in [types.qint8, types.quint8, types.qint32]:
    numpy_dtype = dtype

  if dtype is not None and not dtype.base_dtype == numpy_dtype.base_dtype:
    raise TypeError("Incompatible types: %s vs. %s" % (dtype, nparray.dtype))

  # If shape is not given, get the shape from the numpy array.
  if shape is None:
    shape = nparray.shape
    is_same_size = True
    shape_size = nparray.size
  else:
    shape = [int(dim) for dim in shape]
    shape_size = np.prod(shape)
    is_same_size = shape_size == nparray.size

    if nparray.size > shape_size:
      raise ValueError(
          "Too many elements provided. Needed at most %d, but received %d" %
          (shape_size, nparray.size))

  tensor_proto = tensor_pb2.TensorProto(
      dtype=numpy_dtype.as_datatype_enum,
      tensor_shape=MakeTensorShapeProto(shape))

  if is_same_size and numpy_dtype in _TENSOR_CONTENT_TYPES and shape_size > 1:
    # tobytes() yields the same raw bytes as the deprecated tostring().
    tensor_proto.tensor_content = nparray.tobytes()
    return tensor_proto

  # If we were not given values as a numpy array, compute the proto_values
  # from the given values directly, to avoid numpy trimming nulls from the
  # strings. Since values could be a list of strings, or a multi-dimensional
  # list of lists that might or might not correspond to the given shape,
  # we flatten it conservatively.
  if numpy_dtype == types.string and not isinstance(values, np.ndarray):
    proto_values = _FlattenToStrings(values)
    tensor_proto.string_val.extend([str(x) for x in proto_values])
    return tensor_proto

  # TensorFlow expects C order (a.k.a., eigen row major).
  proto_values = nparray.ravel()

  append_fn = GetNumpyAppendFn(proto_values.dtype)
  if append_fn is None:
    raise TypeError("Element type not supported in TensorProto: %s" %
                    numpy_dtype.name)
  append_fn(tensor_proto, proto_values)

  return tensor_proto
def get_variable(self, name, shape=None, dtype=types.float32, initializer=None,
                 reuse=None, trainable=True, collections=None):
  """Returns the stored variable `name`, creating and storing it if absent.

  Reuse semantics:
    * `reuse=True`: a variable called `name` must already be stored; it is
      returned.
    * `reuse=False`: no variable called `name` may be stored yet; a new one
      is created.
    * `reuse=None` (default): the stored variable is returned if present,
      otherwise a new one is created.

  When a new variable is created and `initializer` is `None`,
  `init_ops.uniform_unit_scaling_initializer()` is used.

  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable; must be fully defined when
      a new variable is created.
    dtype: type of the new or existing variable (defaults to
      `types.float32`).
    initializer: initializer used only when a new variable is created.
    reuse: a Boolean or `None`. Controls reuse or creation of variables.
    trainable: forwarded to `variables.Variable` on creation.
    collections: forwarded to `variables.Variable` on creation.

  Returns:
    The created or existing variable.

  Raises:
    ValueError: when reuse rules are violated, when an existing variable's
      shape or dtype is incompatible with the requested one, or when a new
      variable's shape is not fully defined.
  """
  enforce_reuse = reuse is not None
  dtype = types.as_dtype(dtype)
  shape = tensor_shape.as_shape(shape)

  if name not in self._vars:
    # Nothing stored under this name: create it, unless reuse was required.
    if enforce_reuse and reuse:
      raise ValueError("Under-sharing: Variable %s does not exist, disallowed."
                       " Did you mean to set reuse=None in VarScope?" % name)
    if not shape.is_fully_defined():
      raise ValueError("Shape of a new variable (%s) must be fully defined, "
                       "but instead was %s." % (name, shape))
    if initializer is None:
      initializer = init_ops.uniform_unit_scaling_initializer()
    with ops.name_scope(name + "/Initializer/"):
      init_val = initializer(shape.as_list(), dtype=dtype)
    created = variables.Variable(init_val, name=name, trainable=trainable,
                                 collections=collections)
    self._vars[name] = created
    logging.info("Created variable %s with shape %s and init %s", created.name,
                 format(shape), initializer)
    return created

  # A variable is stored under this name: return it, unless reuse was
  # explicitly forbidden or the requested shape/dtype conflict.
  if enforce_reuse and not reuse:
    raise ValueError("Over-sharing: Variable %s already exists, disallowed."
                     " Did you mean to set reuse=True in VarScope?" % name)
  existing = self._vars[name]
  if not shape.is_compatible_with(existing.get_shape()):
    raise ValueError("Trying to share variable %s, but specified shape %s"
                     " and found shape %s."
                     % (name, shape, existing.get_shape()))
  if not dtype.is_compatible_with(existing.dtype):
    raise ValueError("Trying to share variable %s, but specified dtype %s"
                     " and found dtype %s."
                     % (name, dtype.name, existing.dtype.name))
  return existing
def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph. A message of another class is accepted if it can be
      merged into a `GraphDef` via `MergeFrom`.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef`
      protos. Must contain an `OpDef` proto for each op type named in
      `graph_def`. If omitted, uses the `OpDef` protos registered in the
      global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`, or `None` when
    `return_elements` is `None`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto (and cannot be
      merged into one), `input_map` is not a dictionary mapping strings to
      `Tensor` objects, or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, six.string_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if (return_elements is not None
      and not (isinstance(return_elements, (list, tuple))
               and all(isinstance(x, six.string_types)
                       for x in return_elements))):
    raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()

    with ops.name_scope('_inputs'):
      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      output_types = _OutputTypes(node, op_dict)
      with _MaybeDevice(node.device):
        name_to_op[node.name] = g.create_op(
            node.op, [], output_types, name=node.name, attrs=node.attr,
            compute_shapes=False)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
      op = name_to_op[node.name]
      input_types = _InputTypes(node, op_dict)

      # NOTE(mrry): We cannot use zip here because control inputs do not
      # appear in the list of input_types.
      for i, input_name in enumerate(
          [_CanonicalInputName(x) for x in node.input]):

        if _IsControlInput(input_name):
          # (a) Input is a control input that should be taken from an op
          #     in "graph_def".
          try:
            source_op = name_to_op[input_name[1:]]
          except KeyError:
            raise ValueError(
                _InvalidNodeMessage(
                    node,
                    'Control input %r not found in graph_def.'
                    % (input_name,)))
          # pylint: disable=protected-access
          op._add_control_input(source_op)
          # pylint: enable=protected-access

        else:
          try:
            input_type = input_types[i]
          except IndexError:
            raise ValueError(_InvalidNodeMessage(
                node, 'More inputs specified (%r) than the op expects.'
                % (input_name,)))

          if input_name in input_map:
            # (b) Input should be replaced by a tensor from the caller.
            source_tensor = input_map[input_name]
            used_input_keys.add(input_name)

          else:
            # (c) Input should be taken from an op in `graph_def`.
            operation_name, output_index = _ParseTensorName(input_name)
            try:
              source_op = name_to_op[operation_name]
              source_tensor = list(source_op.values())[output_index]
            except (KeyError, IndexError):
              raise ValueError(
                  _InvalidNodeMessage(
                      node,
                      'Input tensor %r not found in graph_def.'
                      % (input_name,)))

          try:
            # pylint: disable=protected-access
            op._add_input(source_tensor, dtype=input_type)
            # pylint: enable=protected-access
          except TypeError as te:
            # `str(te)` rather than `te.message`: exceptions have no
            # `.message` attribute on Python 3, and accessing it here would
            # raise AttributeError and hide the real type error.
            raise ValueError(_InvalidNodeMessage(
                node, 'Input tensor %r %s' % (input_name, str(te))))

      # pylint: disable=protected-access
      if op._input_dtypes != input_types:
        raise ValueError(
            _InvalidNodeMessage(
                node,
                'Input types mismatch (expected %r but got %r)'
                % (", ".join(types_lib.as_dtype(x).name for x in input_types),
                   ", ".join(x.name for x in op._input_dtypes))))
      # pylint: enable=protected-access

      # Execute shape inference for this op.
      # NOTE(mrry): If the graph contains a cycle, the full shape information
      # may not be available for this op's inputs.
      ops.set_shapes_for_outputs(op)

    # Treat unused input mappings as an error, because they are likely to be
    # due to a typo. Keys are sorted so the message is deterministic.
    unused_input_keys = frozenset(input_map.keys()).difference(used_input_keys)
    if unused_input_keys:
      raise ValueError(
          'Attempted to map inputs that were not found in graph_def: [%s]'
          % ', '.join(sorted(unused_input_keys)))

    if return_elements is None:
      return None
    else:
      ret = []
      # `element_name` (not `name`) so the `name` argument is not shadowed.
      for element_name in return_elements:
        if ':' in element_name:
          try:
            operation_name, output_index = _ParseTensorName(element_name)
            ret.append(name_to_op[operation_name].outputs[output_index])
          except (ValueError, KeyError, IndexError):
            raise ValueError(
                'Requested return_element %r not found in graph_def.'
                % element_name)
        else:
          try:
            ret.append(name_to_op[element_name])
          except KeyError:
            raise ValueError(
                'Requested return_element %r not found in graph_def.'
                % element_name)
      return ret
def _SingleArgToTypes(node_def, arg_def):
  """Returns the DataType enums for `arg_def` as used by `node_def`.

  The enums come from `_ArgToTypesNoRef`; when `arg_def.is_ref` is set, each
  one is mapped through `types_lib.as_dtype(...).as_ref` before being
  converted back to a DataType enum.
  """
  plain_types = _ArgToTypesNoRef(node_def, arg_def)
  if not arg_def.is_ref:
    return plain_types
  return [types_lib.as_dtype(enum_value).as_ref.as_datatype_enum
          for enum_value in plain_types]
def testStringConversion(self):
  # Each dtype name, first in plain form and then with the "_ref" suffix,
  # must map through as_dtype() to the identically named module attribute.
  plain_names = ["float32", "float64", "int32", "uint8", "int16", "int8",
                 "string", "complex64", "int64", "bool", "qint8", "quint8",
                 "qint32", "bfloat16"]
  ref_names = [plain + "_ref" for plain in plain_names]
  for dtype_name in plain_names + ref_names:
    self.assertIs(getattr(types, dtype_name), types.as_dtype(dtype_name))
  # Unknown names are rejected.
  with self.assertRaises(TypeError):
    types.as_dtype("not_a_type")
def testNumpyConversion(self):
  """Checks that numpy scalar types and dtypes map to the expected DTypes."""
  self.assertIs(types.float32, types.as_dtype(np.float32))
  self.assertIs(types.float64, types.as_dtype(np.float64))
  self.assertIs(types.int32, types.as_dtype(np.int32))
  self.assertIs(types.int64, types.as_dtype(np.int64))
  self.assertIs(types.uint8, types.as_dtype(np.uint8))
  self.assertIs(types.int16, types.as_dtype(np.int16))
  self.assertIs(types.int8, types.as_dtype(np.int8))
  self.assertIs(types.complex64, types.as_dtype(np.complex64))
  # `np.object` and `np.bool` were plain aliases of the builtins `object` and
  # `bool`; the aliases were removed in NumPy 1.24, so pass the builtins
  # (the very same objects) directly.
  self.assertIs(types.string, types.as_dtype(object))
  self.assertIs(types.string, types.as_dtype(np.array(["foo", "bar"]).dtype))
  self.assertIs(types.bool, types.as_dtype(bool))
  # Structured dtypes have no DType equivalent.
  with self.assertRaises(TypeError):
    types.as_dtype(np.dtype([("f1", np.uint), ("f2", np.int32)]))
def testInvalid(self):
  # DT_INVALID must be rejected both by the DType constructor and by the
  # as_dtype() conversion helper.
  for build_dtype in (types.DType, types.as_dtype):
    with self.assertRaises(TypeError):
      build_dtype(types_pb2.DT_INVALID)
def get_variable(self, name, shape=None, dtype=types.float32, initializer=None,
                 reuse=None, trainable=True, collections=None):
  """Fetches the variable stored as `name`, or creates and stores a new one.

  `reuse` selects the policy:
    * `True`  -- the variable must already exist and is returned.
    * `False` -- the variable must not exist yet and is created.
    * `None`  -- (default) return it if stored, otherwise create it.

  A newly created variable uses `initializer`, or
  `init_ops.uniform_unit_scaling_initializer()` when `initializer` is `None`.

  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable; must be fully defined for
      a newly created variable.
    dtype: type of the new or existing variable (defaults to
      `types.float32`).
    initializer: initializer used only when a new variable is created.
    reuse: a Boolean or `None`. Controls reuse or creation of variables.
    trainable: forwarded to `variables.Variable` on creation.
    collections: forwarded to `variables.Variable` on creation.

  Returns:
    The created or existing variable.

  Raises:
    ValueError: on a reuse-policy violation, on a shape or dtype conflict
      with the stored variable, or when a new variable's shape is not fully
      defined.
  """
  check_policy = reuse is not None
  dtype = types.as_dtype(dtype)
  shape = tensor_shape.as_shape(shape)
  already_stored = name in self._vars

  if already_stored:
    if check_policy and not reuse:
      raise ValueError("Over-sharing: Variable %s already exists, disallowed."
                       " Did you mean to set reuse=True in VarScope?" % name)
    stored = self._vars[name]
    if not shape.is_compatible_with(stored.get_shape()):
      raise ValueError("Trying to share variable %s, but specified shape %s"
                       " and found shape %s."
                       % (name, str(shape), str(stored.get_shape())))
    if not dtype.is_compatible_with(stored.dtype):
      requested_name = dtype.name
      stored_name = stored.dtype.name
      raise ValueError("Trying to share variable %s, but specified dtype %s"
                       " and found dtype %s."
                       % (name, str(requested_name), str(stored_name)))
    return stored

  # From here on, a new variable must be created.
  if check_policy and reuse:
    raise ValueError("Under-sharing: Variable %s does not exist, disallowed."
                     " Did you mean to set reuse=None in VarScope?" % name)
  if not shape.is_fully_defined():
    raise ValueError("Shape of a new variable (%s) must be fully defined, "
                     "but instead was %s." % (name, shape))
  init_fn = initializer
  if init_fn is None:
    init_fn = init_ops.uniform_unit_scaling_initializer()
  with ops.name_scope(name + "/Initializer/"):
    initial_value = init_fn(shape.as_list(), dtype=dtype)
  new_variable = variables.Variable(initial_value, name=name,
                                    trainable=trainable,
                                    collections=collections)
  self._vars[name] = new_variable
  logging.info("Created variable %s with shape %s and init %s",
               new_variable.name, format(shape), str(init_fn))
  return new_variable
def _NdarrayFromRepeatedField(values, dtype, num_elements, shape):
  """Builds an ndarray of `shape` from a repeated proto field.

  A field holding exactly one value is repeated `num_elements` times;
  otherwise the field's values are taken as-is. The result is reshaped to
  `shape`.
  """
  if len(values) == 1:
    return np.repeat(np.array(values[0], dtype=dtype),
                     num_elements).reshape(shape)
  return np.fromiter(values, dtype=dtype).reshape(shape)


def MakeNdarray(tensor):
  """Create a numpy ndarray from a tensor.

  Create a numpy ndarray with the same shape and data as the tensor.

  Args:
    tensor: A TensorProto.

  Returns:
    A numpy array with the tensor contents.

  Raises:
    TypeError: if tensor has unsupported type.
  """
  shape = [d.size for d in tensor.tensor_shape.dim]
  # For a scalar tensor `shape` is [] and np.prod([]) is the float 1.0;
  # np.repeat() rejects float repeat counts, so force an integer count.
  num_elements = np.prod(shape, dtype=np.int64)
  tensor_dtype = types.as_dtype(tensor.dtype)
  dtype = tensor_dtype.as_numpy_dtype

  if tensor.tensor_content:
    # np.frombuffer replaces the deprecated binary mode of np.fromstring. It
    # returns a read-only view of the bytes, so copy to hand back a writable
    # array as np.fromstring did.
    return np.frombuffer(tensor.tensor_content,
                         dtype=dtype).copy().reshape(shape)
  elif tensor_dtype == types.float32:
    return _NdarrayFromRepeatedField(tensor.float_val, dtype, num_elements,
                                     shape)
  elif tensor_dtype == types.float64:
    return _NdarrayFromRepeatedField(tensor.double_val, dtype, num_elements,
                                     shape)
  elif tensor_dtype in [types.int32, types.uint8, types.int16, types.int8,
                        types.qint32, types.quint8, types.qint8,
                        types.bfloat16]:
    return _NdarrayFromRepeatedField(tensor.int_val, dtype, num_elements,
                                     shape)
  elif tensor_dtype == types.int64:
    return _NdarrayFromRepeatedField(tensor.int64_val, dtype, num_elements,
                                     shape)
  elif tensor_dtype == types.string:
    if len(tensor.string_val) == 1:
      return np.repeat(np.array(str(tensor.string_val[0]), dtype=dtype),
                       num_elements).reshape(shape)
    else:
      return np.array([str(x) for x in tensor.string_val],
                      dtype=dtype).reshape(shape)
  elif tensor_dtype == types.complex64:
    # scomplex_val holds consecutive (real, imaginary) pairs.
    if len(tensor.scomplex_val) == 2:
      return np.repeat(
          np.array(complex(tensor.scomplex_val[0], tensor.scomplex_val[1]),
                   dtype=dtype),
          num_elements).reshape(shape)
    else:
      it = iter(tensor.scomplex_val)
      return np.array([complex(x[0], x[1]) for x in zip(it, it)],
                      dtype=dtype).reshape(shape)
  elif tensor_dtype == types.bool:
    return _NdarrayFromRepeatedField(tensor.bool_val, dtype, num_elements,
                                     shape)
  else:
    raise TypeError("Unsupported tensor type: %s" % tensor.dtype)
def apply_op(self, op_type_name, g=None, name=None, **keywords):  # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # If none of the inputs are Tensors and your session doesn't have a
       # default graph, you will have to specify the graph.
       op_def_library.apply_op("op", input1=input1, g=g)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    An input or attr whose name collides with a Python keyword/built-in may
    be passed with a trailing underscore (e.g. `lambda_=...`).

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      g: The graph context (optional)
      name: string. Optional name of the created op; defaults to
        `op_type_name`.
      **keywords: input Tensor and attr arguments specified by name. Any
        keyword that does not name an input or attr of the OpDef is
        rejected with a TypeError.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if the OpDef declares no outputs.

    Raises:
      RuntimeError: If `op_type_name` is not registered, or the graph cannot
        be determined from the inputs.
      TypeError: On missing or unexpected arguments, or input/attr type
        mismatches.
      ValueError: On list-length mismatches, values below an attr's minimum,
        or strings outside an attr's allowed values.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
        raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
        # Need to flatten all the arguments into a list.
        # pylint: disable=protected-access
        g = ops._get_graph_from_inputs(_Flatten(keywords.values()), graph=g)
        # pylint: enable=protected-access
    except AssertionError as e:
        # str(e) rather than e.message: the latter is Python-2 only and
        # deprecated there.
        raise RuntimeError(
            "Need to specify g=graph to Op '%s' (could not determine graph due "
            "to: %s)" % (op_type_name, str(e)))

    # Default name if not specified.
    if name is None:
        name = op_type_name

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).

    attrs = {}  # attr name -> inferred or caller-supplied value
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

        # Perform input type inference
        # attr name -> name of the input argument its value was inferred
        # from; only used to make error messages point at the culprit.
        inferred_from = {}
        for input_arg in op_def.input_arg:
            input_name = input_arg.name
            if input_name in keywords:
                values = keywords.pop(input_name)
            elif input_name + "_" in keywords:
                # Handle the case where the name is a keyword or built-in
                # for Python so we use the name + _ instead.
                input_name += "_"
                values = keywords.pop(input_name)
            else:
                raise TypeError("No argument for input " + input_name)

            # Goals:
            # * Convert values to Tensors if it contains constants.
            # * Verify that values is a list if that matches the input_arg's
            #   type.
            # * If the input_arg's type is determined by attrs, either set
            #   those attrs and validate those attr values are legal (if
            #   they have not yet been set) or validate the input matches
            #   the type indicated by the attrs (if they have already been
            #   inferred via an earlier input).
            # * If the input_arg has an explicit type, make sure the input
            #   conforms.

            if _IsListParameter(input_arg):
                if not _IsListValue(values):
                    raise TypeError(
                        "Expected list for '%s' argument to '%s' Op, not %s." %
                        (input_name, op_type_name, values))
                # In cases where we expect all elements of the list to have the
                # same dtype, try to cast non-Tensor elements to that type.
                dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.number_attr:
                    if input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                    else:
                        for t in values:
                            if isinstance(t, ops.Tensor):
                                dtype = t.dtype
                                break

                try:
                    values = ops.convert_n_to_tensor_or_indexed_slices(
                        values, name=input_arg.name,
                        dtype=(types_lib.as_dtype(dtype).base_dtype
                               if dtype else None))
                except (TypeError, ValueError):
                    assert dtype is not None, "Should not fail if dtype is None"
                    assert input_arg.number_attr, "Should be number_attr case"
                    # What types does the conversion function think values have?
                    values = ops.convert_n_to_tensor_or_indexed_slices(values)
                    observed = ", ".join(v.dtype.base_dtype.name for v in values)
                    prefix = (
                        "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                        (input_name, op_type_name, observed))
                    if input_arg.type != types_pb2.DT_INVALID:
                        raise TypeError("%s that do not match expected type %s." %
                                        (prefix, types_lib.as_dtype(dtype).name))
                    elif input_arg.type_attr in attrs:
                        raise TypeError("%s that do not match type %s inferred from "
                                        "earlier arguments." %
                                        (prefix, types_lib.as_dtype(dtype).name))
                    else:
                        raise TypeError("%s that don't all match." % prefix)

                types = [x.dtype for x in values]
                inputs.extend(values)
            else:
                # In cases where we have an expected type, try to convert non-Tensor
                # arguments to that type.
                dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.type_attr in attrs:
                    dtype = attrs[input_arg.type_attr]

                try:
                    values = ops.convert_to_tensor(
                        values, name=input_arg.name, dtype=dtype)
                except ValueError:
                    # What type does convert_to_tensor think it has?
                    observed = ops.convert_to_tensor(values).dtype.name
                    prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                              (input_name, op_type_name, observed))
                    if input_arg.type != types_pb2.DT_INVALID:
                        raise TypeError("%s expected type of %s." %
                                        (prefix,
                                         types_lib.as_dtype(input_arg.type).name))
                    else:
                        raise TypeError(
                            "%s type %s of argument '%s'." %
                            (prefix,
                             types_lib.as_dtype(attrs[input_arg.type_attr]).name,
                             inferred_from[input_arg.type_attr]))

                types = [values.dtype]
                inputs.append(values)
            base_types = [x.base_dtype for x in types]

            if input_arg.number_attr:
                # <number-attr> * <type> or <number-attr> * <type-attr>
                if input_arg.number_attr in attrs:
                    if len(values) != attrs[input_arg.number_attr]:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d must match "
                            "length %d of argument '%s'." %
                            (input_name, op_type_name, len(values),
                             attrs[input_arg.number_attr],
                             inferred_from[input_arg.number_attr]))
                else:
                    attrs[input_arg.number_attr] = len(values)
                    inferred_from[input_arg.number_attr] = input_name
                    num_attr = _Attr(op_def, input_arg.number_attr)
                    if num_attr.has_minimum and len(values) < num_attr.minimum:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d shorter "
                            "than minimum length %d." %
                            (input_name, op_type_name, len(values),
                             num_attr.minimum))
                # All tensors must have the same base type.
                if any(bt != base_types[0] for bt in base_types):
                    raise TypeError(
                        "All tensors passed to '%s' of '%s' Op "
                        "must have the same type." %
                        (input_name, op_type_name))
                # In the two "Unreachable" cases below the conversion above was
                # already given this dtype, so a mismatch is not expected.
                if input_arg.type != types_pb2.DT_INVALID:
                    # <number-attr> * <type> case
                    if base_types and base_types[0] != input_arg.type:
                        assert False, "Unreachable"
                elif input_arg.type_attr in attrs:
                    # <number-attr> * <type-attr> case, where <type-attr> already
                    # has an inferred value.
                    if base_types and base_types[0] != attrs[input_arg.type_attr]:
                        assert False, "Unreachable"
                else:
                    # <number-attr> * <type-attr> case, where we are now setting
                    # the <type-attr> based on this input
                    if not base_types:
                        raise TypeError(
                            "Don't know how to infer type variable from empty input "
                            "list passed to input '%s' of '%s' Op." %
                            (input_name, op_type_name))
                    attrs[input_arg.type_attr] = base_types[0]
                    inferred_from[input_arg.type_attr] = input_name
                    type_attr = _Attr(op_def, input_arg.type_attr)
                    _SatisfiesTypeConstraint(base_types[0], type_attr)
            elif input_arg.type_attr:
                # <type-attr>
                attr_value = base_types[0]
                if input_arg.type_attr in attrs:
                    if attrs[input_arg.type_attr] != attr_value:
                        assert False, "Unreachable"
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(
                            base_type, _Attr(op_def, input_arg.type_attr))
                    attrs[input_arg.type_attr] = attr_value
                    inferred_from[input_arg.type_attr] = input_name
            elif input_arg.type_list_attr:
                # <type-list-attr>
                attr_value = base_types
                if input_arg.type_list_attr in attrs:
                    if attrs[input_arg.type_list_attr] != attr_value:
                        raise TypeError(
                            "Input '%s' of '%s' Op has type list of %s that does not "
                            "match type list %s of argument '%s'." %
                            (input_name, op_type_name,
                             ", ".join(types_lib.as_dtype(x).name
                                       for x in attr_value),
                             ", ".join(types_lib.as_dtype(x).name
                                       for x in attrs[input_arg.type_list_attr]),
                             inferred_from[input_arg.type_list_attr]))
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(
                            base_type, _Attr(op_def, input_arg.type_list_attr))
                    attrs[input_arg.type_list_attr] = attr_value
                    inferred_from[input_arg.type_list_attr] = input_name
            else:
                # single Tensor with specified type
                if base_types[0] != input_arg.type:
                    assert False, "Unreachable"

            if input_arg.is_ref:
                if not all(x.is_ref_dtype for x in types):
                    raise TypeError(
                        "Input '%s' of '%s' Op requires l-value input" %
                        (input_name, op_type_name))
                input_types.extend(types)
            else:
                input_types.extend(base_types)

        # Process remaining attrs
        for attr in op_def.attr:
            # Skip attrs that have already had their values inferred
            if attr.name in attrs:
                if attr.name in keywords:
                    raise TypeError(
                        "Should not specify value for inferred attr '%s'." %
                        attr.name)
                continue
            if attr.name in keywords:
                attrs[attr.name] = keywords.pop(attr.name)
            elif attr.name + "_" in keywords:
                # Attrs whose names match Python keywords have an extra '_'
                # appended, so we must check for that as well.
                attrs[attr.name] = keywords.pop(attr.name + "_")
            else:
                raise TypeError("No argument for attr " + attr.name)

        # Convert attr values to AttrValue protos.
        attr_protos = {}
        for attr_def in op_def.attr:
            key = attr_def.name
            value = attrs[key]
            attr_value = attr_value_pb2.AttrValue()
            # Passing None for an attr that has a default selects the default.
            if attr_def.HasField("default_value") and value is None:
                attr_value.CopyFrom(attr_def.default_value)
                attr_protos[key] = attr_value
                continue
            if attr_def.type.startswith("list("):
                if not _IsListValue(value):
                    raise TypeError("Expected list for attr " + key)
                if attr_def.has_minimum:
                    if len(value) < attr_def.minimum:
                        raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                                         "less than minimum %d." %
                                         (key, op_type_name, len(value),
                                          attr_def.minimum))
            if attr_def.type == "string":
                attr_value.s = _MakeStr(value, key)
                if attr_def.HasField("allowed_values"):
                    if attr_value.s not in attr_def.allowed_values.list.s:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                            (key, op_type_name, attr_value.s,
                             '", "'.join(attr_def.allowed_values.list.s)))
            elif attr_def.type == "list(string)":
                attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                if attr_def.HasField("allowed_values"):
                    for x in attr_value.list.s:
                        if x not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                                (key, op_type_name, x,
                                 '", "'.join(attr_def.allowed_values.list.s)))
            elif attr_def.type == "int":
                attr_value.i = _MakeInt(value, key)
                if attr_def.has_minimum:
                    if attr_value.i < attr_def.minimum:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                            (key, op_type_name, attr_value.i, attr_def.minimum))
            elif attr_def.type == "list(int)":
                attr_value.list.i.extend([_MakeInt(x, key) for x in value])
            elif attr_def.type == "float":
                attr_value.f = _MakeFloat(value, key)
            elif attr_def.type == "list(float)":
                attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
            elif attr_def.type == "bool":
                attr_value.b = _MakeBool(value, key)
            elif attr_def.type == "list(bool)":
                attr_value.list.b.extend([_MakeBool(x, key) for x in value])
            elif attr_def.type == "type":
                attr_value.type = _MakeType(value, attr_def)
            elif attr_def.type == "list(type)":
                attr_value.list.type.extend(
                    [_MakeType(x, attr_def) for x in value])
            elif attr_def.type == "shape":
                attr_value.shape.CopyFrom(_MakeShape(value, key))
            elif attr_def.type == "list(shape)":
                attr_value.list.shape.extend(
                    [_MakeShape(x, key) for x in value])
            elif attr_def.type == "tensor":
                attr_value.tensor.CopyFrom(_MakeTensor(value, key))
            elif attr_def.type == "list(tensor)":
                attr_value.list.tensor.extend(
                    [_MakeTensor(x, key) for x in value])
            else:
                raise TypeError("Unrecognized Attr type " + attr_def.type)

            attr_protos[key] = attr_value
        del attrs  # attrs is no longer authoritative, use attr_protos instead

        # Determine output types (possibly using attrs)
        output_types = []
        # One entry per output arg: the list length for list outputs, or None
        # for a single tensor; used by _Restructure to regroup op.outputs.
        output_structure = []
        for arg in op_def.output_arg:
            types = []
            if arg.number_attr:
                n = _AttrValue(attr_protos, arg.number_attr).i
                if arg.type_attr:
                    types = [_AttrValue(attr_protos, arg.type_attr).type] * n
                else:
                    types = [arg.type] * n
                output_structure.append(n)
            elif arg.type_attr:
                t = _AttrValue(attr_protos, arg.type_attr)
                types = [t.type]
                output_structure.append(None)
            elif arg.type_list_attr:
                t = _AttrValue(attr_protos, arg.type_list_attr)
                types = t.list.type
                output_structure.append(len(t.list.type))
            else:
                types = [arg.type]
                output_structure.append(None)
            if arg.is_ref:
                types = [types_lib.as_dtype(x).as_ref for x in types]
            output_types.extend(types)

        # Every input and attr keyword has been popped by now; anything left
        # over is a name the OpDef does not declare.
        if keywords:
            raise TypeError("apply_op() got unexpected keyword arguments: " +
                            ", ".join(sorted(keywords.keys())))

        # Add Op to graph
        if output_structure:
            op = g.create_op(op_type_name, inputs, output_types, name=scope,
                             input_types=input_types, attrs=attr_protos,
                             op_def=op_def)
            outputs = op.outputs
            return _Restructure(ops.convert_n_to_tensor_or_indexed_slices(outputs),
                                output_structure)
        else:
            return g.create_op(op_type_name, inputs, output_types, name=scope,
                               input_types=input_types, attrs=attr_protos,
                               op_def=op_def)
def apply_op(self, op_type_name, g=None, name=None, **keywords):  # pylint: disable=g-doc-args
    """Add a node invoking a registered Op to a graph.

    Example usage:
       # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
       # will convert to a Tensor.
       op_def_library.apply_op("op", input1=input1, input2=input2)
       # If none of the inputs are Tensors and your session doesn't have a
       # default graph, you will have to specify the graph.
       op_def_library.apply_op("op", input1=input1, g=g)
       # Can specify a node name.
       op_def_library.apply_op("op", input1=input1, name="node_name")
       # Must use keyword arguments, with the names specified in the OpDef.
       op_def_library.apply_op("op", input_name=input, attr_name=attr)

    All attrs must either be inferred from an input or specified.
    (If inferred, the attr must not be specified.)  If an attr has a default
    value specified in the Op's OpDef, then you may pass None as the value
    of that attr to get the default.

    An input or attr whose name collides with a Python keyword/built-in may
    be passed with a trailing underscore (e.g. `lambda_=...`).

    Args:
      op_type_name: string. Must match the name field of a registered Op.
      g: The graph context (optional)
      name: string. Optional name of the created op; defaults to
        `op_type_name`.
      **keywords: input Tensor and attr arguments specified by name. Any
        keyword that does not name an input or attr of the OpDef is
        rejected with a TypeError.

    Returns:
      The Tensor(s) representing the output of the operation, or the Operation
      itself if the OpDef declares no outputs.

    Raises:
      RuntimeError: If `op_type_name` is not registered, or the graph cannot
        be determined from the inputs.
      TypeError: On missing or unexpected arguments, or input/attr type
        mismatches.
      ValueError: On list-length mismatches, values below an attr's minimum,
        or strings outside an attr's allowed values.
    """
    op_info = self._ops.get(op_type_name, None)
    if op_info is None:
        raise RuntimeError("Unrecognized Op name " + op_type_name)
    op_def = op_info.op_def

    # Determine the graph context.
    try:
        # Need to flatten all the arguments into a list.
        # pylint: disable=protected-access
        g = ops._get_graph_from_inputs(_Flatten(keywords.values()), graph=g)
        # pylint: enable=protected-access
    except AssertionError as e:
        # str(e) rather than e.message: the latter is Python-2 only and
        # deprecated there.
        raise RuntimeError(
            "Need to specify g=graph to Op '%s' (could not determine graph due "
            "to: %s)" % (op_type_name, str(e)))

    # Default name if not specified.
    if name is None:
        name = op_type_name

    # Requires that op_def has passed validation (using the C++
    # ValidateOpDef() from ../framework/op_def_util.h).

    attrs = {}  # attr name -> inferred or caller-supplied value
    inputs = []
    input_types = []
    with g.as_default(), ops.name_scope(name) as scope:

        # Perform input type inference
        # attr name -> name of the input argument its value was inferred
        # from; only used to make error messages point at the culprit.
        inferred_from = {}
        for input_arg in op_def.input_arg:
            input_name = input_arg.name
            if input_name in keywords:
                values = keywords.pop(input_name)
            elif input_name + "_" in keywords:
                # Handle the case where the name is a keyword or built-in
                # for Python so we use the name + _ instead.
                input_name += "_"
                values = keywords.pop(input_name)
            else:
                raise TypeError("No argument for input " + input_name)

            # Goals:
            # * Convert values to Tensors if it contains constants.
            # * Verify that values is a list if that matches the input_arg's
            #   type.
            # * If the input_arg's type is determined by attrs, either set
            #   those attrs and validate those attr values are legal (if
            #   they have not yet been set) or validate the input matches
            #   the type indicated by the attrs (if they have already been
            #   inferred via an earlier input).
            # * If the input_arg has an explicit type, make sure the input
            #   conforms.

            if _IsListParameter(input_arg):
                if not _IsListValue(values):
                    raise TypeError(
                        "Expected list for '%s' argument to '%s' Op, not %s." %
                        (input_name, op_type_name, values))
                # In cases where we expect all elements of the list to have the
                # same dtype, try to cast non-Tensor elements to that type.
                dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.number_attr:
                    if input_arg.type_attr in attrs:
                        dtype = attrs[input_arg.type_attr]
                    else:
                        for t in values:
                            if isinstance(t, ops.Tensor):
                                dtype = t.dtype
                                break

                try:
                    values = ops.convert_n_to_tensor_or_indexed_slices(
                        values, name=input_arg.name,
                        dtype=(types_lib.as_dtype(dtype).base_dtype
                               if dtype else None))
                except (TypeError, ValueError):
                    assert dtype is not None, "Should not fail if dtype is None"
                    assert input_arg.number_attr, "Should be number_attr case"
                    # What types does the conversion function think values have?
                    values = ops.convert_n_to_tensor_or_indexed_slices(values)
                    observed = ", ".join(v.dtype.base_dtype.name for v in values)
                    prefix = (
                        "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                        (input_name, op_type_name, observed))
                    if input_arg.type != types_pb2.DT_INVALID:
                        raise TypeError("%s that do not match expected type %s." %
                                        (prefix, types_lib.as_dtype(dtype).name))
                    elif input_arg.type_attr in attrs:
                        raise TypeError("%s that do not match type %s inferred from "
                                        "earlier arguments." %
                                        (prefix, types_lib.as_dtype(dtype).name))
                    else:
                        raise TypeError("%s that don't all match." % prefix)

                types = [x.dtype for x in values]
                inputs.extend(values)
            else:
                # In cases where we have an expected type, try to convert non-Tensor
                # arguments to that type.
                dtype = None
                if input_arg.type != types_pb2.DT_INVALID:
                    dtype = input_arg.type
                elif input_arg.type_attr in attrs:
                    dtype = attrs[input_arg.type_attr]

                try:
                    values = ops.convert_to_tensor(
                        values, name=input_arg.name, dtype=dtype)
                except ValueError:
                    # What type does convert_to_tensor think it has?
                    observed = ops.convert_to_tensor(values).dtype.name
                    prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                              (input_name, op_type_name, observed))
                    if input_arg.type != types_pb2.DT_INVALID:
                        raise TypeError("%s expected type of %s." %
                                        (prefix,
                                         types_lib.as_dtype(input_arg.type).name))
                    else:
                        raise TypeError(
                            "%s type %s of argument '%s'." %
                            (prefix,
                             types_lib.as_dtype(attrs[input_arg.type_attr]).name,
                             inferred_from[input_arg.type_attr]))

                types = [values.dtype]
                inputs.append(values)
            base_types = [x.base_dtype for x in types]

            if input_arg.number_attr:
                # <number-attr> * <type> or <number-attr> * <type-attr>
                if input_arg.number_attr in attrs:
                    if len(values) != attrs[input_arg.number_attr]:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d must match "
                            "length %d of argument '%s'." %
                            (input_name, op_type_name, len(values),
                             attrs[input_arg.number_attr],
                             inferred_from[input_arg.number_attr]))
                else:
                    attrs[input_arg.number_attr] = len(values)
                    inferred_from[input_arg.number_attr] = input_name
                    num_attr = _Attr(op_def, input_arg.number_attr)
                    if num_attr.has_minimum and len(values) < num_attr.minimum:
                        raise ValueError(
                            "List argument '%s' to '%s' Op with length %d shorter "
                            "than minimum length %d." %
                            (input_name, op_type_name, len(values),
                             num_attr.minimum))
                # All tensors must have the same base type.
                if any(bt != base_types[0] for bt in base_types):
                    raise TypeError(
                        "All tensors passed to '%s' of '%s' Op "
                        "must have the same type." %
                        (input_name, op_type_name))
                # In the two "Unreachable" cases below the conversion above was
                # already given this dtype, so a mismatch is not expected.
                if input_arg.type != types_pb2.DT_INVALID:
                    # <number-attr> * <type> case
                    if base_types and base_types[0] != input_arg.type:
                        assert False, "Unreachable"
                elif input_arg.type_attr in attrs:
                    # <number-attr> * <type-attr> case, where <type-attr> already
                    # has an inferred value.
                    if base_types and base_types[0] != attrs[input_arg.type_attr]:
                        assert False, "Unreachable"
                else:
                    # <number-attr> * <type-attr> case, where we are now setting
                    # the <type-attr> based on this input
                    if not base_types:
                        raise TypeError(
                            "Don't know how to infer type variable from empty input "
                            "list passed to input '%s' of '%s' Op." %
                            (input_name, op_type_name))
                    attrs[input_arg.type_attr] = base_types[0]
                    inferred_from[input_arg.type_attr] = input_name
                    type_attr = _Attr(op_def, input_arg.type_attr)
                    _SatisfiesTypeConstraint(base_types[0], type_attr)
            elif input_arg.type_attr:
                # <type-attr>
                attr_value = base_types[0]
                if input_arg.type_attr in attrs:
                    if attrs[input_arg.type_attr] != attr_value:
                        assert False, "Unreachable"
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(
                            base_type, _Attr(op_def, input_arg.type_attr))
                    attrs[input_arg.type_attr] = attr_value
                    inferred_from[input_arg.type_attr] = input_name
            elif input_arg.type_list_attr:
                # <type-list-attr>
                attr_value = base_types
                if input_arg.type_list_attr in attrs:
                    if attrs[input_arg.type_list_attr] != attr_value:
                        raise TypeError(
                            "Input '%s' of '%s' Op has type list of %s that does not "
                            "match type list %s of argument '%s'." %
                            (input_name, op_type_name,
                             ", ".join(types_lib.as_dtype(x).name
                                       for x in attr_value),
                             ", ".join(types_lib.as_dtype(x).name
                                       for x in attrs[input_arg.type_list_attr]),
                             inferred_from[input_arg.type_list_attr]))
                else:
                    for base_type in base_types:
                        _SatisfiesTypeConstraint(
                            base_type, _Attr(op_def, input_arg.type_list_attr))
                    attrs[input_arg.type_list_attr] = attr_value
                    inferred_from[input_arg.type_list_attr] = input_name
            else:
                # single Tensor with specified type
                if base_types[0] != input_arg.type:
                    assert False, "Unreachable"

            if input_arg.is_ref:
                if not all(x.is_ref_dtype for x in types):
                    raise TypeError(
                        "Input '%s' of '%s' Op requires l-value input" %
                        (input_name, op_type_name))
                input_types.extend(types)
            else:
                input_types.extend(base_types)

        # Process remaining attrs
        for attr in op_def.attr:
            # Skip attrs that have already had their values inferred
            if attr.name in attrs:
                if attr.name in keywords:
                    raise TypeError(
                        "Should not specify value for inferred attr '%s'." %
                        attr.name)
                continue
            if attr.name in keywords:
                attrs[attr.name] = keywords.pop(attr.name)
            elif attr.name + "_" in keywords:
                # Attrs whose names match Python keywords have an extra '_'
                # appended, so we must check for that as well.
                attrs[attr.name] = keywords.pop(attr.name + "_")
            else:
                raise TypeError("No argument for attr " + attr.name)

        # Convert attr values to AttrValue protos.
        attr_protos = {}
        for attr_def in op_def.attr:
            key = attr_def.name
            value = attrs[key]
            attr_value = attr_value_pb2.AttrValue()
            # Passing None for an attr that has a default selects the default.
            if attr_def.HasField("default_value") and value is None:
                attr_value.CopyFrom(attr_def.default_value)
                attr_protos[key] = attr_value
                continue
            if attr_def.type.startswith("list("):
                if not _IsListValue(value):
                    raise TypeError("Expected list for attr " + key)
                if attr_def.has_minimum:
                    if len(value) < attr_def.minimum:
                        raise ValueError("Attr '%s' of '%s' Op passed list of length %d "
                                         "less than minimum %d." %
                                         (key, op_type_name, len(value),
                                          attr_def.minimum))
            if attr_def.type == "string":
                attr_value.s = _MakeStr(value, key)
                if attr_def.HasField("allowed_values"):
                    if attr_value.s not in attr_def.allowed_values.list.s:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                            (key, op_type_name, attr_value.s,
                             '", "'.join(attr_def.allowed_values.list.s)))
            elif attr_def.type == "list(string)":
                attr_value.list.s.extend([_MakeStr(x, key) for x in value])
                if attr_def.HasField("allowed_values"):
                    for x in attr_value.list.s:
                        if x not in attr_def.allowed_values.list.s:
                            raise ValueError(
                                "Attr '%s' of '%s' Op passed string '%s' not in: \"%s\"." %
                                (key, op_type_name, x,
                                 '", "'.join(attr_def.allowed_values.list.s)))
            elif attr_def.type == "int":
                attr_value.i = _MakeInt(value, key)
                if attr_def.has_minimum:
                    if attr_value.i < attr_def.minimum:
                        raise ValueError(
                            "Attr '%s' of '%s' Op passed %d less than minimum %d." %
                            (key, op_type_name, attr_value.i, attr_def.minimum))
            elif attr_def.type == "list(int)":
                attr_value.list.i.extend([_MakeInt(x, key) for x in value])
            elif attr_def.type == "float":
                attr_value.f = _MakeFloat(value, key)
            elif attr_def.type == "list(float)":
                attr_value.list.f.extend([_MakeFloat(x, key) for x in value])
            elif attr_def.type == "bool":
                attr_value.b = _MakeBool(value, key)
            elif attr_def.type == "list(bool)":
                attr_value.list.b.extend([_MakeBool(x, key) for x in value])
            elif attr_def.type == "type":
                attr_value.type = _MakeType(value, attr_def)
            elif attr_def.type == "list(type)":
                attr_value.list.type.extend(
                    [_MakeType(x, attr_def) for x in value])
            elif attr_def.type == "shape":
                attr_value.shape.CopyFrom(_MakeShape(value, key))
            elif attr_def.type == "list(shape)":
                attr_value.list.shape.extend(
                    [_MakeShape(x, key) for x in value])
            elif attr_def.type == "tensor":
                attr_value.tensor.CopyFrom(_MakeTensor(value, key))
            elif attr_def.type == "list(tensor)":
                attr_value.list.tensor.extend(
                    [_MakeTensor(x, key) for x in value])
            else:
                raise TypeError("Unrecognized Attr type " + attr_def.type)

            attr_protos[key] = attr_value
        del attrs  # attrs is no longer authoritative, use attr_protos instead

        # Determine output types (possibly using attrs)
        output_types = []
        # One entry per output arg: the list length for list outputs, or None
        # for a single tensor; used by _Restructure to regroup op.outputs.
        output_structure = []
        for arg in op_def.output_arg:
            types = []
            if arg.number_attr:
                n = _AttrValue(attr_protos, arg.number_attr).i
                if arg.type_attr:
                    types = [_AttrValue(attr_protos, arg.type_attr).type] * n
                else:
                    types = [arg.type] * n
                output_structure.append(n)
            elif arg.type_attr:
                t = _AttrValue(attr_protos, arg.type_attr)
                types = [t.type]
                output_structure.append(None)
            elif arg.type_list_attr:
                t = _AttrValue(attr_protos, arg.type_list_attr)
                types = t.list.type
                output_structure.append(len(t.list.type))
            else:
                types = [arg.type]
                output_structure.append(None)
            if arg.is_ref:
                types = [types_lib.as_dtype(x).as_ref for x in types]
            output_types.extend(types)

        # Every input and attr keyword has been popped by now; anything left
        # over is a name the OpDef does not declare.
        if keywords:
            raise TypeError("apply_op() got unexpected keyword arguments: " +
                            ", ".join(sorted(keywords.keys())))

        # Add Op to graph
        if output_structure:
            op = g.create_op(op_type_name, inputs, output_types, name=scope,
                             input_types=input_types, attrs=attr_protos,
                             op_def=op_def)
            outputs = op.outputs
            return _Restructure(ops.convert_n_to_tensor_or_indexed_slices(outputs),
                                output_structure)
        else:
            return g.create_op(op_type_name, inputs, output_types, name=scope,
                               input_types=input_types, attrs=attr_protos,
                               op_def=op_def)