def testPreserveTensorArrayShape(self):
  """Round-tripping a TensorArray through its tensor list keeps element_shape."""
  original = tensor_array_ops.TensorArray(
      dtype=dtypes.int32, size=1, element_shape=(3,))
  spec = structure.type_spec_from_value(original)
  flattened = structure.to_tensor_list(spec, original)
  restored = structure.from_tensor_list(spec, flattened)
  self.assertEqual(restored.element_shape.as_list(), [3])
def testPreserveInferredTensorArrayShape(self):
  """An element_shape inferred from a write survives the tensor-list round trip."""
  original = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=1)
  # No explicit element_shape: the write below is what fixes it to [3].
  original = original.write(0, [1, 2, 3])
  spec = structure.type_spec_from_value(original)
  flattened = structure.to_tensor_list(spec, original)
  restored = structure.from_tensor_list(spec, flattened)
  self.assertEqual(restored.element_shape.as_list(), [3])
def testPreserveStaticShape(self):
  """Static shape info of ragged/sparse tensors survives a tensor-list round trip.

  NOTE(review): this method was originally named `preserveStaticShape`, so
  unittest discovery never ran it; renamed with the `test` prefix so it
  actually executes.
  """
  # RaggedTensor: row_splits keeps its static shape, values become unknown-length.
  rt = ragged_factory_ops.constant([[1, 2], [], [3]])
  rt_s = structure.type_spec_from_value(rt)
  rt_after = structure.from_tensor_list(rt_s,
                                        structure.to_tensor_list(rt_s, rt))
  self.assertEqual(rt_after.row_splits.shape.as_list(),
                   rt.row_splits.shape.as_list())
  self.assertEqual(rt_after.values.shape.as_list(), [None])

  # SparseTensor: rank (last dim of indices) and dense_shape rank are static;
  # the number of non-zero entries is not.
  st = sparse_tensor.SparseTensor(
      indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
  st_s = structure.type_spec_from_value(st)
  st_after = structure.from_tensor_list(st_s,
                                        structure.to_tensor_list(st_s, st))
  self.assertEqual(st_after.indices.shape.as_list(), [None, 2])
  self.assertEqual(st_after.values.shape.as_list(), [None])
  self.assertEqual(st_after.dense_shape.shape.as_list(),
                   st.dense_shape.shape.as_list())
def py_function_wrapper(*args):
  """Restructures flat tensor args, applies the user func, re-flattens the result."""
  structured_args = structure.from_compatible_tensor_list(
      self._input_structure, args)
  # A single (non-unpackable) structure is passed as one positional argument.
  if not _should_unpack(structured_args):
    structured_args = (structured_args,)
  result = self._func(*structured_args)
  if _should_pack(result):
    result = tuple(result)
  flat_result = structure.to_tensor_list(self._output_structure, result)
  return [ops.convert_to_tensor(component) for component in flat_result]
def compress(element):
  """Compress a dataset element.

  Args:
    element: A nested structure of types supported by Tensorflow.

  Returns:
    A variant tensor representing the compressed element. This variant can be
    passed to `uncompress` to get back the original element.
  """
  spec = structure.type_spec_from_value(element)
  flat_tensors = structure.to_tensor_list(spec, element)
  return ged_ops.compress_element(flat_tensors)
def _next_func(string_handle):
  """Calls get_next for created iterator.

  Args:
    string_handle: An iterator string handle created by _init_func

  Returns:
    The elements generated from `input_dataset`
  """
  # Recreate the iterator on the source device so get_next ops are placed there.
  with ops.device(self._source_device_string):
    iterator = iterator_ops.Iterator.from_string_handle(
        string_handle,
        dataset_ops.get_legacy_output_types(self),
        dataset_ops.get_legacy_output_shapes(self),
        dataset_ops.get_legacy_output_classes(self))
    return structure.to_tensor_list(self.element_spec, iterator.get_next())
def from_value(value):
  """Returns an `Optional` that wraps the given value.

  Args:
    value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

  Returns:
    An `Optional` that wraps `value`.
  """
  with ops.name_scope("optional") as scope:
    # Flatten the value inside a nested "value" scope for readable graph names.
    with ops.name_scope("value"):
      spec = structure.type_spec_from_value(value)
      flat_value = structure.to_tensor_list(spec, value)
    return _OptionalImpl(
        gen_dataset_ops.optional_from_value(flat_value, name=scope), spec)
def from_value(value):
  """Returns an `Optional` that wraps the given value.

  Args:
    value: A value to wrap. The value must be convertible to `Tensor` or
      `CompositeTensor`.

  Returns:
    An `Optional` that wraps `value`.
  """
  with ops.name_scope("optional") as scope:
    # Encode the wrapped value as a flat tensor list under a "value" sub-scope.
    with ops.name_scope("value"):
      spec = structure.type_spec_from_value(value)
      flat_value = structure.to_tensor_list(spec, value)
    return _OptionalImpl(
        gen_dataset_ops.optional_from_value(flat_value, name=scope), spec)
def from_value(value):
  """Returns a `tf.experimental.Optional` that wraps the given value.

  >>> optional = tf.experimental.Optional.from_value(42)
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> print(optional.get_value())
  tf.Tensor(42, shape=(), dtype=int32)

  Args:
    value: A value to wrap. The value must be convertible to `tf.Tensor` or
      `tf.CompositeTensor`.

  Returns:
    A `tf.experimental.Optional` that wraps `value`.
  """
  with ops.name_scope("optional") as scope:
    # The flattened encoding is built under a nested "value" name scope.
    with ops.name_scope("value"):
      spec = structure.type_spec_from_value(value)
      flat_value = structure.to_tensor_list(spec, value)
    return _OptionalImpl(
        gen_dataset_ops.optional_from_value(flat_value, name=scope), spec)
def __init__(self, input_dataset, initial_state, scan_func,
             use_default_device=None):
  """See `scan()` for details."""
  self._input_dataset = input_dataset
  self._initial_state = structure.normalize_element(initial_state)

  # Compute initial values for the state classes, shapes and types based on
  # the initial state. The shapes may be refined by running `tf_scan_func` one
  # or more times below.
  self._state_structure = structure.type_spec_from_value(self._initial_state)

  # Iteratively rerun the scan function until reaching a fixed point on
  # `self._state_shapes`: tracing the function can weaken (generalize) the
  # state shapes, and it must be re-traced with the weakened structure until
  # the shapes stop changing.
  need_to_rerun = True
  while need_to_rerun:
    wrapped_func = dataset_ops.StructuredFunctionWrapper(
        scan_func,
        self._transformation_name(),
        input_structure=(self._state_structure, input_dataset.element_spec),
        add_to_graph=False)
    if not (isinstance(wrapped_func.output_types, collections_abc.Sequence)
            and len(wrapped_func.output_types) == 2):
      raise TypeError("The scan function must return a pair comprising the "
                      "new state and the output value.")
    # NOTE(review): this unpack is immediately repeated below; the only extra
    # effect of this first line is setting `self._output_classes` — presumably
    # kept for legacy attribute access. Verify before removing.
    new_state_classes, self._output_classes = wrapped_func.output_classes

    # Extract and validate class information from the returned values.
    new_state_classes, output_classes = wrapped_func.output_classes
    old_state_classes = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._state_structure)
    for new_state_class, old_state_class in zip(
        nest.flatten(new_state_classes), nest.flatten(old_state_classes)):
      if not issubclass(new_state_class, old_state_class):
        raise TypeError(
            "The element classes for the new state must match the initial "
            "state. Expected %s; got %s." %
            (old_state_classes, new_state_classes))

    # Extract and validate type information from the returned values.
    new_state_types, output_types = wrapped_func.output_types
    old_state_types = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._state_structure)
    for new_state_type, old_state_type in zip(
        nest.flatten(new_state_types), nest.flatten(old_state_types)):
      if new_state_type != old_state_type:
        raise TypeError(
            "The element types for the new state must match the initial "
            "state. Expected %s; got %s." %
            (old_state_types, new_state_types))

    # Extract shape information from the returned values.
    new_state_shapes, output_shapes = wrapped_func.output_shapes
    old_state_shapes = nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._state_structure)
    self._element_spec = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    # Weaken each state shape to the most specific shape compatible with both
    # the old and the newly-traced shape.
    flat_state_shapes = nest.flatten(old_state_shapes)
    flat_new_state_shapes = nest.flatten(new_state_shapes)
    weakened_state_shapes = [
        original.most_specific_compatible_shape(new)
        for original, new in zip(flat_state_shapes, flat_new_state_shapes)
    ]

    # Re-trace only if weakening actually changed some shape (lost a known
    # rank or changed a dimension list).
    need_to_rerun = False
    for original_shape, weakened_shape in zip(flat_state_shapes,
                                              weakened_state_shapes):
      if original_shape.ndims is not None and (
          weakened_shape.ndims is None or
          original_shape.as_list() != weakened_shape.as_list()):
        need_to_rerun = True
        break

    if need_to_rerun:
      # TODO(b/110122868): Support a "most specific compatible structure"
      # method for combining structures, to avoid using legacy structures
      # in this method.
      self._state_structure = structure.convert_legacy_structure(
          old_state_types,
          nest.pack_sequence_as(old_state_shapes, weakened_state_shapes),
          old_state_classes)

  self._scan_func = wrapped_func
  self._scan_func.function.add_to_graph(ops.get_default_graph())

  # pylint: disable=protected-access
  # `use_default_device` is only supported by the newer op signature, gated on
  # forward compatibility.
  if compat.forward_compatible(2019, 10,
                               15) or use_default_device is not None:
    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
        self._input_dataset._variant_tensor,
        structure.to_tensor_list(self._state_structure, self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        use_default_device=use_default_device,
        **self._flat_structure)
  else:
    variant_tensor = gen_experimental_dataset_ops.scan_dataset(
        self._input_dataset._variant_tensor,
        structure.to_tensor_list(self._state_structure, self._initial_state),
        self._scan_func.function.captured_inputs,
        f=self._scan_func.function,
        preserve_cardinality=True,
        **self._flat_structure)
  super(_ScanDataset, self).__init__(input_dataset, variant_tensor)
def testIncompatibleNestedStructure(self):
  # Define three mutually incompatible nested values/structures, and assert
  # that:
  # 1. Using one structure to flatten a value with an incompatible structure
  #    fails.
  # 2. Using one structure to restructure a flattened value with an
  #    incompatible structure fails.

  # `value_0`: dict of two dense tensors -> flattens to 2 tensors.
  value_0 = {
      "a": constant_op.constant(37.0),
      "b": constant_op.constant([1, 2, 3])
  }
  s_0 = structure.type_spec_from_value(value_0)
  flat_s_0 = structure.to_tensor_list(s_0, value_0)

  # `value_1` has compatible nested structure with `value_0`, but different
  # classes ("b" is sparse instead of dense).
  value_1 = {
      "a":
          constant_op.constant(37.0),
      "b":
          sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1])
  }
  s_1 = structure.type_spec_from_value(value_1)
  flat_s_1 = structure.to_tensor_list(s_1, value_1)

  # `value_2` has incompatible nested structure with `value_0` and `value_1`
  # ("b" is a tuple of two sparse tensors -> flattens to 3 tensors).
  value_2 = {
      "a":
          constant_op.constant(37.0),
      "b": (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
            sparse_tensor.SparseTensor(
                indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
  }
  s_2 = structure.type_spec_from_value(value_2)
  flat_s_2 = structure.to_tensor_list(s_2, value_2)

  # Flattening with a mismatched structure should fail with an error that
  # names the incompatibility.
  with self.assertRaisesRegex(
      ValueError, r"SparseTensor.* is not convertible to a tensor with "
      r"dtype.*int32.* and shape \(3,\)"):
    structure.to_tensor_list(s_0, value_1)
  with self.assertRaisesRegex(
      ValueError, "The two structures don't have the same nested structure."):
    structure.to_tensor_list(s_0, value_2)
  with self.assertRaisesRegex(TypeError,
                              "neither a SparseTensor nor SparseTensorValue"):
    structure.to_tensor_list(s_1, value_0)
  with self.assertRaisesRegex(
      ValueError, "The two structures don't have the same nested structure."):
    structure.to_tensor_list(s_1, value_2)

  # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
  # needs to account for "a" coming before or after "b". It might be worth
  # adding a deterministic repr for these error messages (among other
  # improvements).
def testIncompatibleStructure(self):
  # Define three mutually incompatible values/structures, and assert that:
  # 1. Using one structure to flatten a value with an incompatible structure
  #    fails.
  # 2. Using one structure to restructure a flattened value with an
  #    incompatible structure fails.

  # Scalar dense tensor -> flattens to 1 tensor.
  value_tensor = constant_op.constant(42.0)
  s_tensor = structure.type_spec_from_value(value_tensor)
  flat_tensor = structure.to_tensor_list(s_tensor, value_tensor)

  # Sparse tensor -> flattens to 1 (variant-encoded) tensor.
  value_sparse_tensor = sparse_tensor.SparseTensor(
      indices=[[0, 0]], values=[1], dense_shape=[1, 1])
  s_sparse_tensor = structure.type_spec_from_value(value_sparse_tensor)
  flat_sparse_tensor = structure.to_tensor_list(s_sparse_tensor,
                                                value_sparse_tensor)

  # Dict of two dense tensors -> flattens to 2 tensors.
  value_nest = {
      "a": constant_op.constant(37.0),
      "b": constant_op.constant([1, 2, 3])
  }
  s_nest = structure.type_spec_from_value(value_nest)
  flat_nest = structure.to_tensor_list(s_nest, value_nest)

  # Flattening a value with a mismatched spec fails with a descriptive error.
  with self.assertRaisesRegex(
      ValueError, r"SparseTensor.* is not convertible to a tensor with "
      r"dtype.*float32.* and shape \(\)"):
    structure.to_tensor_list(s_tensor, value_sparse_tensor)
  with self.assertRaisesRegex(
      ValueError, "The two structures don't have the same nested structure."):
    structure.to_tensor_list(s_tensor, value_nest)
  with self.assertRaisesRegex(TypeError,
                              "neither a SparseTensor nor SparseTensorValue"):
    structure.to_tensor_list(s_sparse_tensor, value_tensor)
  with self.assertRaisesRegex(
      ValueError, "The two structures don't have the same nested structure."):
    structure.to_tensor_list(s_sparse_tensor, value_nest)
  with self.assertRaisesRegex(
      ValueError, "The two structures don't have the same nested structure."):
    structure.to_tensor_list(s_nest, value_tensor)
  with self.assertRaisesRegex(
      ValueError, "The two structures don't have the same nested structure."):
    structure.to_tensor_list(s_nest, value_sparse_tensor)

  # Restructuring a flat list with the wrong spec fails on component type or
  # on the flat-list length.
  with self.assertRaisesRegex(
      ValueError,
      "Cannot create a Tensor from the tensor list because item 0 "
      ".*tf.Tensor.* is incompatible with the expected TypeSpec "
      ".*TensorSpec.*"):
    structure.from_tensor_list(s_tensor, flat_sparse_tensor)
  with self.assertRaisesRegex(ValueError, "Expected 1 tensors but got 2."):
    structure.from_tensor_list(s_tensor, flat_nest)
  with self.assertRaisesRegex(
      ValueError,
      "Cannot create a SparseTensor from the tensor list because "
      "item 0 .*tf.Tensor.* is incompatible with the expected TypeSpec "
      ".*TensorSpec.*"):
    structure.from_tensor_list(s_sparse_tensor, flat_tensor)
  with self.assertRaisesRegex(ValueError, "Expected 1 tensors but got 2."):
    structure.from_tensor_list(s_sparse_tensor, flat_nest)
  with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 1."):
    structure.from_tensor_list(s_nest, flat_tensor)
  with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 1."):
    structure.from_tensor_list(s_nest, flat_sparse_tensor)
def wrapped_fn(*args):  # pylint: disable=missing-docstring
  # Run the user function, then flatten the structured result and force each
  # component to a Tensor.
  result = wrapper_helper(*args)
  flat_result = structure.to_tensor_list(self._output_structure, result)
  return [ops.convert_to_tensor(component) for component in flat_result]
def wrapped_fn(*args):
  """Applies the wrapped function and flattens its result to a tensor list."""
  return structure.to_tensor_list(self._output_structure,
                                  wrapper_helper(*args))