def testNestedNestedStructure(self):
  s = (structure.TensorStructure(dtypes.int64, []),
       (structure.TensorStructure(dtypes.float32, []),
        structure.TensorStructure(dtypes.string, [])))

  int64_t = constant_op.constant(37, dtype=dtypes.int64)
  float32_t = constant_op.constant(42.0)
  string_t = constant_op.constant("Foo")

  nested_tensors = (int64_t, (float32_t, string_t))

  tensor_list = structure.to_tensor_list(s, nested_tensors)
  for expected, actual in zip([int64_t, float32_t, string_t], tensor_list):
    self.assertIs(expected, actual)

  (actual_int64_t,
   (actual_float32_t, actual_string_t)) = structure.from_tensor_list(
       s, tensor_list)
  self.assertIs(int64_t, actual_int64_t)
  self.assertIs(float32_t, actual_float32_t)
  self.assertIs(string_t, actual_string_t)

  (actual_int64_t,
   (actual_float32_t,
    actual_string_t)) = structure.from_compatible_tensor_list(s, tensor_list)
  self.assertIs(int64_t, actual_int64_t)
  self.assertIs(float32_t, actual_float32_t)
  self.assertIs(string_t, actual_string_t)
def testNestedNestedStructure(self):
  # Although `Structure.from_value()` will not construct one, a nested
  # structure containing nested `NestedStructure` objects can occur if a
  # structure is constructed manually.
  s = structure.NestedStructure(
      (structure.TensorStructure(dtypes.int64, []),
       structure.NestedStructure(
           (structure.TensorStructure(dtypes.float32, []),
            structure.TensorStructure(dtypes.string, [])))))

  int64_t = constant_op.constant(37, dtype=dtypes.int64)
  float32_t = constant_op.constant(42.0)
  string_t = constant_op.constant("Foo")

  nested_tensors = (int64_t, (float32_t, string_t))

  tensor_list = s._to_tensor_list(nested_tensors)
  for expected, actual in zip([int64_t, float32_t, string_t], tensor_list):
    self.assertIs(expected, actual)

  (actual_int64_t,
   (actual_float32_t, actual_string_t)) = s._from_tensor_list(tensor_list)
  self.assertIs(int64_t, actual_int64_t)
  self.assertIs(float32_t, actual_float32_t)
  self.assertIs(string_t, actual_string_t)

  (actual_int64_t,
   (actual_float32_t,
    actual_string_t)) = s._from_compatible_tensor_list(tensor_list)
  self.assertIs(int64_t, actual_int64_t)
  self.assertIs(float32_t, actual_float32_t)
  self.assertIs(string_t, actual_string_t)
def _make_window_size_func(self, window_size_func):
  """Make wrapping defun for window_size_func."""

  def window_size_func_wrapper(key):
    return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)

  self._window_size_func = dataset_ops.StructuredFunctionWrapper(
      window_size_func_wrapper,
      self._transformation_name(),
      input_structure=structure.TensorStructure(dtypes.int64, []))
  if not self._window_size_func.output_structure.is_compatible_with(
      structure.TensorStructure(dtypes.int64, [])):
    raise ValueError(
        "`window_size_func` must return a single tf.int64 scalar tensor.")
def testFromNone(self):
  value_structure = structure.TensorStructure(dtypes.float32, [])
  opt = optional_ops.Optional.none_from_structure(value_structure)
  self.assertTrue(opt.value_structure.is_compatible_with(value_structure))
  self.assertFalse(
      opt.value_structure.is_compatible_with(
          structure.TensorStructure(dtypes.float32, [1])))
  self.assertFalse(
      opt.value_structure.is_compatible_with(
          structure.TensorStructure(dtypes.int32, [])))
  self.assertFalse(self.evaluate(opt.has_value()))
  with self.assertRaises(errors.InvalidArgumentError):
    self.evaluate(opt.get_value())
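# A minimal usage sketch (not part of the original test) of the same Optional
# behavior through the public `tf.data.experimental.Optional` alias of this
# era; it mirrors what `testFromNone` asserts.
import tensorflow.compat.v1 as tf

opt_with_value = tf.data.experimental.Optional.from_value(tf.constant(1.0))
opt_none = tf.data.experimental.Optional.none_from_structure(
    opt_with_value.value_structure)
# `opt_none.has_value()` evaluates to False, and evaluating
# `opt_none.get_value()` raises an InvalidArgumentError.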
def testImportedFunctionsRegistered(self):
  if test.is_built_with_gpu_support():
    self.skipTest("Disabling this new test due to errors with cuda and rocm")
  with ops.Graph().as_default() as graph:
    x = array_ops.placeholder(dtypes.variant, shape=[], name="foo")
    ds = dataset_ops.from_variant(
        x, structure=(structure.TensorStructure(dtypes.int32, [])))
    y = ds.reduce(array_ops.zeros([], dtype=dtypes.int32), lambda p, q: p + q)

    graph_def = graph.as_graph_def()

  def fn_to_wrap(a):
    returned_elements = graph_def_importer.import_graph_def(
        graph_def, input_map={x.name: a}, return_elements=[y.name])
    return returned_elements[0]

  wrapped_fn = wrap_function.wrap_function(
      fn_to_wrap, [tensor_spec.TensorSpec((), dtypes.variant)])
  ds = dataset_ops.Dataset.from_tensor_slices([10, 20])
  v = dataset_ops.to_variant(ds)
  self.evaluate(wrapped_fn(v))
def testCopyToGPU(self):
  if not test_util.is_gpu_available():
    self.skipTest("No GPU available")

  with ops.device("/cpu:0"):
    optional_with_value = optional_ops.Optional.from_value(
        (constant_op.constant(37.0), constant_op.constant("Foo"),
         constant_op.constant(42)))
    optional_none = optional_ops.Optional.none_from_structure(
        structure.TensorStructure(dtypes.float32, []))

  with ops.device("/gpu:0"):
    gpu_optional_with_value = optional_ops._OptionalImpl(
        array_ops.identity(optional_with_value._variant_tensor),
        optional_with_value.value_structure)
    gpu_optional_none = optional_ops._OptionalImpl(
        array_ops.identity(optional_none._variant_tensor),
        optional_none.value_structure)

    gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
    gpu_optional_with_value_values = gpu_optional_with_value.get_value()

    gpu_optional_none_has_value = gpu_optional_none.has_value()

  self.assertTrue(self.evaluate(gpu_optional_with_value_has_value))
  self.assertEqual((37.0, b"Foo", 42),
                   self.evaluate(gpu_optional_with_value_values))
  self.assertFalse(self.evaluate(gpu_optional_none_has_value))
def write(self, dataset):
  """Returns a `tf.Operation` to write a dataset to a file.

  Args:
    dataset: a `tf.data.Dataset` whose elements are to be written to a file.

  Returns:
    A `tf.Operation` that, when run, writes contents of `dataset` to a file.
  """
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
  if not dataset_ops.get_structure(dataset).is_compatible_with(
      structure.TensorStructure(dtypes.string, [])):
    raise TypeError(
        "`dataset` must produce scalar `DT_STRING` tensors whereas it "
        "produces shape {0} and types {1}".format(
            dataset_ops.get_legacy_output_shapes(dataset),
            dataset_ops.get_legacy_output_types(dataset)))
  if compat.forward_compatible(2019, 8, 3):
    return gen_experimental_dataset_ops.dataset_to_tf_record(
        dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
  else:
    return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
        dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
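# Usage sketch (not from the original source): this `write()` is the core of
# `tf.data.experimental.TFRecordWriter`. The output path below is
# hypothetical, and in graph mode the returned op must be run in a session.
import tensorflow.compat.v1 as tf

dataset_of_strings = tf.data.Dataset.from_tensor_slices([b"rec0", b"rec1"])
writer = tf.data.experimental.TFRecordWriter("/tmp/example.tfrecord")
write_op = writer.write(dataset_of_strings)  # a tf.Operation
with tf.Session() as sess:
  sess.run(write_op)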
def __init__(self, input_dataset, predicate):
  """See `take_while()` for details."""

  self._input_dataset = input_dataset
  wrapped_func = dataset_ops.StructuredFunctionWrapper(
      predicate,
      "tf.data.experimental.take_while()",
      dataset=self._input_dataset)

  if not wrapped_func.output_structure.is_compatible_with(
      structure_lib.TensorStructure(dtypes.bool, [])):
    raise ValueError("`predicate` must return a scalar boolean tensor.")

  self._predicate = wrapped_func
  if compat.forward_compatible(2019, 8, 3):
    var_tensor = gen_experimental_dataset_ops.take_while_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        other_arguments=self._predicate.function.captured_inputs,
        predicate=self._predicate.function,
        **self._flat_structure)
  else:
    var_tensor = gen_experimental_dataset_ops.experimental_take_while_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        other_arguments=self._predicate.function.captured_inputs,
        predicate=self._predicate.function,
        **self._flat_structure)
  super(_TakeWhileDataset, self).__init__(input_dataset, var_tensor)
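# Usage sketch (assumption: `_TakeWhileDataset` backs the public
# `tf.data.experimental.take_while` transformation named in the wrapper above):
import tensorflow.compat.v1 as tf

dataset = tf.data.Dataset.range(10).apply(
    tf.data.experimental.take_while(lambda x: x < 5))
# Yields 0..4; the predicate must return a scalar tf.bool tensor, which is
# exactly what the structure check above enforces.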
def _make_key_func(self, key_func, input_dataset):
  """Make wrapping defun for key_func."""
  self._key_func = dataset_ops.StructuredFunctionWrapper(
      key_func, self._transformation_name(), dataset=input_dataset)
  if not self._key_func.output_structure.is_compatible_with(
      structure.TensorStructure(dtypes.int64, [])):
    raise ValueError(
        "`key_func` must return a single tf.int64 scalar tensor. "
        "Got type=%s and shape=%s" %
        (self._key_func.output_types, self._key_func.output_shapes))
def _make_key_func(self, key_func, input_dataset):
  """Make wrapping defun for key_func."""

  def key_func_wrapper(*args):
    return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)

  self._key_func = dataset_ops.StructuredFunctionWrapper(
      key_func_wrapper, self._transformation_name(), dataset=input_dataset)
  if not self._key_func.output_structure.is_compatible_with(
      structure.TensorStructure(dtypes.int64, [])):
    raise ValueError(
        "`key_func` must return a single tf.int64 scalar tensor.")
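# Usage sketch (assumption: this `_make_key_func` belongs to the dataset
# behind `tf.data.experimental.group_by_window`, a typical consumer of a
# key_func like the one validated above):
import tensorflow.compat.v1 as tf

grouped = tf.data.Dataset.range(10).apply(
    tf.data.experimental.group_by_window(
        key_func=lambda x: x % 2,  # wrapped and coerced to a scalar tf.int64
        reduce_func=lambda _, window: window.batch(5),
        window_size=5))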
def __init__(self, client_resource, selected_fields, output_types,
             avro_schema, stream):
  self._structure = structure.NestedStructure(
      tuple(structure.TensorStructure(dtype, []) for dtype in output_types))

  variant_tensor = _bigquery_so.big_query_dataset(
      client=client_resource,
      selected_fields=selected_fields,
      output_types=output_types,
      avro_schema=avro_schema,
      stream=stream)
  super(_BigQueryDataset, self).__init__(variant_tensor)
def __init__(self, input_dataset, features, num_parallel_calls):
  super(_ParseExampleDataset, self).__init__(input_dataset)
  self._input_dataset = input_dataset
  if not input_dataset._element_structure.is_compatible_with(  # pylint: disable=protected-access
      structure.TensorStructure(dtypes.string, [None])):
    raise TypeError("Input dataset should be a dataset of vectors of strings")
  self._num_parallel_calls = num_parallel_calls
  # pylint: disable=protected-access
  self._features = parsing_ops._prepend_none_dimension(features)
  # sparse_keys and dense_keys come back sorted here.
  (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
   dense_shapes) = parsing_ops._features_to_raw_params(
       self._features, [
           parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
           parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
       ])
  # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
  (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
   dense_shape_as_shape) = parsing_ops._process_raw_parameters(
       None, dense_defaults, sparse_keys, sparse_types, dense_keys,
       dense_types, dense_shapes)
  # pylint: enable=protected-access
  self._sparse_keys = sparse_keys
  self._sparse_types = sparse_types
  self._dense_keys = dense_keys
  self._dense_defaults = dense_defaults_vec
  self._dense_shapes = dense_shapes
  self._dense_types = dense_types

  dense_output_shapes = [
      self._input_dataset.output_shapes.concatenate(shape)
      for shape in dense_shape_as_shape
  ]
  sparse_output_shapes = [
      self._input_dataset.output_shapes.concatenate([None])
      for _ in range(len(sparse_keys))
  ]

  output_shapes = dict(
      zip(self._dense_keys + self._sparse_keys,
          dense_output_shapes + sparse_output_shapes))
  output_types = dict(
      zip(self._dense_keys + self._sparse_keys,
          self._dense_types + self._sparse_types))
  output_classes = dict(
      zip(self._dense_keys + self._sparse_keys,
          [ops.Tensor for _ in range(len(self._dense_defaults))] +
          [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
          ]))
  self._structure = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)
def _make_reduce_func(self, reduce_func, input_dataset):
  """Make wrapping defun for reduce_func."""
  nested_dataset = dataset_ops.DatasetStructure(
      input_dataset._element_structure)  # pylint: disable=protected-access
  input_structure = structure.NestedStructure(
      (structure.TensorStructure(dtypes.int64, []), nested_dataset))
  self._reduce_func = dataset_ops.StructuredFunctionWrapper(
      reduce_func,
      self._transformation_name(),
      input_structure=input_structure)
  if not isinstance(self._reduce_func.output_structure,
                    dataset_ops.DatasetStructure):
    raise TypeError("`reduce_func` must return a `Dataset` object.")
  # pylint: disable=protected-access
  self._structure = self._reduce_func.output_structure._element_structure
def __init__(self, driver_name, data_source_name, query, output_types):
  """Creates a `SqlDataset`.

  `SqlDataset` allows a user to read data from the result set of a SQL query.
  For example:

  ```python
  tf.compat.v1.enable_eager_execution()

  dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3",
                                            "SELECT name, age FROM people",
                                            (tf.string, tf.int32))
  # Prints the rows of the result set of the above query.
  for element in dataset:
    print(element)
  ```

  Args:
    driver_name: A 0-D `tf.string` tensor containing the database type.
      Currently, the only supported value is 'sqlite'.
    data_source_name: A 0-D `tf.string` tensor containing a connection string
      to connect to the database.
    query: A 0-D `tf.string` tensor containing the SQL query to execute.
    output_types: A tuple of `tf.DType` objects representing the types of the
      columns returned by `query`.
  """
  self._driver_name = ops.convert_to_tensor(
      driver_name, dtype=dtypes.string, name="driver_name")
  self._data_source_name = ops.convert_to_tensor(
      data_source_name, dtype=dtypes.string, name="data_source_name")
  self._query = ops.convert_to_tensor(
      query, dtype=dtypes.string, name="query")
  self._structure = structure.NestedStructure(
      nest.map_structure(
          lambda dtype: structure.TensorStructure(dtype, []), output_types))
  if compat.forward_compatible(2019, 8, 3):
    variant_tensor = gen_experimental_dataset_ops.sql_dataset(
        self._driver_name, self._data_source_name, self._query,
        **self._flat_structure)
  else:
    variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
        self._driver_name, self._data_source_name, self._query,
        **self._flat_structure)
  super(SqlDatasetV2, self).__init__(variant_tensor)
def choose_from_datasets_v2(datasets, choice_dataset):
  """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A list of `tf.data.Dataset` objects with compatible structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
      `0` and `len(datasets) - 1`.

  Returns:
    A dataset that interleaves elements from `datasets` according to the
    values of `choice_dataset`.

  Raises:
    TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
      type.
  """
  if not structure.are_compatible(
      choice_dataset.element_spec,
      structure.TensorStructure(dtypes.int64, [])):
    raise TypeError("`choice_dataset` must be a dataset of scalar "
                    "`tf.int64` tensors.")
  return _DirectedInterleaveDataset(choice_dataset, datasets)
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testFromValue(self):
    opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
    self.assertTrue(self.evaluate(opt.has_value()))
    self.assertEqual(37.0, self.evaluate(opt.get_value()))

  def testFromStructuredValue(self):
    opt = optional_ops.Optional.from_value({
        "a": constant_op.constant(37.0),
        "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
    })
    self.assertTrue(self.evaluate(opt.has_value()))
    self.assertEqual({
        "a": 37.0,
        "b": ([b"Foo"], b"Bar")
    }, self.evaluate(opt.get_value()))

  def testFromSparseTensor(self):
    st_0 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0]]),
        values=np.array([0], dtype=np.int64),
        dense_shape=np.array([1]))
    st_1 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 1]]),
        values=np.array([-1., 1.], dtype=np.float32),
        dense_shape=np.array([2, 2]))
    opt = optional_ops.Optional.from_value((st_0, st_1))
    self.assertTrue(self.evaluate(opt.has_value()))
    val_0, val_1 = opt.get_value()
    for expected, actual in [(st_0, val_0), (st_1, val_1)]:
      self.assertAllEqual(expected.indices, self.evaluate(actual.indices))
      self.assertAllEqual(expected.values, self.evaluate(actual.values))
      self.assertAllEqual(expected.dense_shape,
                          self.evaluate(actual.dense_shape))

  def testFromNone(self):
    value_structure = structure.TensorStructure(dtypes.float32, [])
    opt = optional_ops.Optional.none_from_structure(value_structure)
    self.assertTrue(opt.value_structure.is_compatible_with(value_structure))
    self.assertFalse(
        opt.value_structure.is_compatible_with(
            structure.TensorStructure(dtypes.float32, [1])))
    self.assertFalse(
        opt.value_structure.is_compatible_with(
            structure.TensorStructure(dtypes.int32, [])))
    self.assertFalse(self.evaluate(opt.has_value()))
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(opt.get_value())

  def testCopyToGPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/cpu:0"):
      optional_with_value = optional_ops.Optional.from_value(
          (constant_op.constant(37.0), constant_op.constant("Foo"),
           constant_op.constant(42)))
      optional_none = optional_ops.Optional.none_from_structure(
          structure.TensorStructure(dtypes.float32, []))

    with ops.device("/gpu:0"):
      gpu_optional_with_value = optional_ops._OptionalImpl(
          array_ops.identity(optional_with_value._variant_tensor),
          optional_with_value.value_structure)
      gpu_optional_none = optional_ops._OptionalImpl(
          array_ops.identity(optional_none._variant_tensor),
          optional_none.value_structure)

      gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
      gpu_optional_with_value_values = gpu_optional_with_value.get_value()

      gpu_optional_none_has_value = gpu_optional_none.has_value()

    self.assertTrue(self.evaluate(gpu_optional_with_value_has_value))
    self.assertEqual((37.0, b"Foo", 42),
                     self.evaluate(gpu_optional_with_value_values))
    self.assertFalse(self.evaluate(gpu_optional_none_has_value))

  def _assertElementValueEqual(self, expected, actual):
    if isinstance(expected, dict):
      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
      for k in expected.keys():
        self._assertElementValueEqual(expected[k], actual[k])
    elif isinstance(expected, sparse_tensor.SparseTensorValue):
      self.assertAllEqual(expected.indices, actual.indices)
      self.assertAllEqual(expected.values, actual.values)
      self.assertAllEqual(expected.dense_shape, actual.dense_shape)
    else:
      self.assertAllEqual(expected, actual)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
      }, structure.NestedStructure({
          "a": structure.TensorStructure(dtypes.float32, []),
          "b": (structure.TensorStructure(dtypes.string, [1]),
                structure.TensorStructure(dtypes.string, []))
      })),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testSkipEagerOptionalStructure(self, tf_value_fn,
                                     expected_value_structure):
    tf_value = tf_value_fn()
    opt = optional_ops.Optional.from_value(tf_value)

    self.assertTrue(
        expected_value_structure.is_compatible_with(opt.value_structure))
    self.assertTrue(
        opt.value_structure.is_compatible_with(expected_value_structure))

    opt_structure = structure.Structure.from_value(opt)
    self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
    self.assertTrue(opt_structure.is_compatible_with(opt_structure))
    self.assertTrue(
        opt_structure._value_structure.is_compatible_with(
            expected_value_structure))
    self.assertEqual([dtypes.variant], opt_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes)

    # All OptionalStructure objects are not compatible with a non-optional
    # value.
    non_optional_structure = structure.Structure.from_value(
        constant_op.constant(42.0))
    self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))

    # Assert that the optional survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_opt = opt_structure._from_tensor_list(
        opt_structure._to_tensor_list(opt))
    if isinstance(tf_value, optional_ops.Optional):
      self.assertEqual(
          self.evaluate(tf_value.get_value()),
          self.evaluate(round_trip_opt.get_value().get_value()))
    else:
      self.assertEqual(
          self.evaluate(tf_value),
          self.evaluate(round_trip_opt.get_value()))

  @parameterized.named_parameters(
      ("Tensor", np.array([1, 2, 3], dtype=np.int32),
       lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
      ("SparseTensor", sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]],
          values=np.array([-1., 1.], dtype=np.float32),
          dense_shape=[2, 2]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
           dense_shape=[2, 2]), False),
      ("Nest", {
          "a": np.array([1, 2, 3], dtype=np.int32),
          "b": sparse_tensor.SparseTensorValue(
              indices=[[0, 0], [1, 1]],
              values=np.array([-1., 1.], dtype=np.float32),
              dense_shape=[2, 2])
      }, lambda: {
          "a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
          "b": sparse_tensor.SparseTensor(
              indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
              dense_shape=[2, 2])
      }, False),
  )
  def testSkipEagerIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                             works_on_gpu):
    if not works_on_gpu and test.is_gpu_available():
      self.skipTest("Test case not yet supported on GPU.")
    ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
    iterator = ds.make_initializable_iterator()
    next_elem = iterator_ops.get_next_as_optional(iterator)
    self.assertIsInstance(next_elem, optional_ops.Optional)
    self.assertTrue(
        next_elem.value_structure.is_compatible_with(
            structure.Structure.from_value(tf_value_fn())))
    elem_has_value_t = next_elem.has_value()
    elem_value_t = next_elem.get_value()
    with self.cached_session() as sess:
      # Before initializing the iterator, evaluating the optional fails with
      # a FailedPreconditionError.
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(elem_has_value_t)
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(elem_value_t)

      # For each element of the dataset, assert that the optional evaluates to
      # the expected value.
      sess.run(iterator.initializer)
      for _ in range(3):
        elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
        self.assertTrue(elem_has_value)
        self._assertElementValueEqual(np_value, elem_value)

      # After exhausting the iterator, `next_elem.has_value()` will evaluate
      # to false, and attempting to get the value will fail.
      for _ in range(2):
        self.assertFalse(sess.run(elem_has_value_t))
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(elem_value_t)
def element_spec(self):
  return structure.TensorStructure(dtypes.string, [])
def __init__(self,
             filenames,
             record_defaults,
             compression_type=None,
             buffer_size=None,
             header=False,
             field_delim=",",
             use_quote_delim=True,
             na_value="",
             select_cols=None):
  """Creates a `CsvDataset` by reading and decoding CSV files.

  The elements of this dataset correspond to records from the file(s).
  RFC 4180 format is expected for CSV files
  (https://tools.ietf.org/html/rfc4180).
  Note that we allow leading and trailing spaces with int or float fields.

  For example, suppose we have a file 'my_file0.csv' with four CSV columns of
  different data types:
  ```
  abcdefg,4.28E10,5.55E6,12
  hijklmn,-5.3E14,,2
  ```

  We can construct a CsvDataset from it as follows:

  ```python
  tf.compat.v1.enable_eager_execution()

  dataset = tf.data.experimental.CsvDataset(
      "my_file*.csv",
      [tf.float32,  # Required field, use dtype or empty tensor
       tf.constant([0.0], dtype=tf.float32),  # Optional field, default to 0.0
       tf.int32,  # Required field, use dtype or empty tensor
      ],
      select_cols=[1, 2, 3]  # Only parse last three columns
  )
  ```

  The expected output of its iterations is:

  ```python
  for element in dataset:
    print(element)

  >> (4.28e10, 5.55e6, 12)
  >> (-5.3e14, 0.0, 2)
  ```

  Args:
    filenames: A `tf.string` tensor containing one or more filenames.
    record_defaults: A list of default values for the CSV fields. Each item in
      the list is either a valid CSV `DType` (float32, float64, int32, int64,
      string), or a `Tensor` object with one of the above types. One per
      column of CSV data, with either a scalar `Tensor` default value for the
      column if it is optional, or `DType` or empty `Tensor` if required. If
      both this and `select_columns` are specified, these must have the same
      lengths, and `column_defaults` is assumed to be sorted in order of
      increasing column index.
    compression_type: (Optional.) A `tf.string` scalar evaluating to one of
      `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
      compression.
    buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
      to buffer while reading files. Defaults to 4MB.
    header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
      have header line(s) that should be skipped when parsing. Defaults to
      `False`.
    field_delim: (Optional.) A `tf.string` scalar containing the delimiter
      character that separates fields in a record. Defaults to `","`.
    use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
      double quotation marks as regular characters inside of string fields
      (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
    na_value: (Optional.) A `tf.string` scalar indicating a value that will
      be treated as NA/NaN.
    select_cols: (Optional.) A sorted list of column indices to select from
      the input data. If specified, only this subset of columns will be
      parsed. Defaults to parsing all columns.
  """
  self._filenames = ops.convert_to_tensor(
      filenames, dtype=dtypes.string, name="filenames")
  self._compression_type = convert.optional_param_to_tensor(
      "compression_type",
      compression_type,
      argument_default="",
      argument_dtype=dtypes.string)
  record_defaults = [
      constant_op.constant([], dtype=x)
      if x in _ACCEPTABLE_CSV_TYPES else x for x in record_defaults
  ]
  self._record_defaults = ops.convert_n_to_tensor(
      record_defaults, name="record_defaults")
  self._buffer_size = convert.optional_param_to_tensor(
      "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
  self._header = ops.convert_to_tensor(
      header, dtype=dtypes.bool, name="header")
  self._field_delim = ops.convert_to_tensor(
      field_delim, dtype=dtypes.string, name="field_delim")
  self._use_quote_delim = ops.convert_to_tensor(
      use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
  self._na_value = ops.convert_to_tensor(
      na_value, dtype=dtypes.string, name="na_value")
  self._select_cols = convert.optional_param_to_tensor(
      "select_cols",
      select_cols,
      argument_default=[],
      argument_dtype=dtypes.int64,
  )
  self._structure = structure.NestedStructure(
      tuple(structure.TensorStructure(d.dtype, [])
            for d in self._record_defaults))
  variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
      filenames=self._filenames,
      record_defaults=self._record_defaults,
      buffer_size=self._buffer_size,
      header=self._header,
      output_shapes=self._structure._flat_shapes,  # pylint: disable=protected-access
      field_delim=self._field_delim,
      use_quote_delim=self._use_quote_delim,
      na_value=self._na_value,
      select_cols=self._select_cols,
      compression_type=self._compression_type)
  super(CsvDatasetV2, self).__init__(variant_tensor)
def _element_structure(self):
  return (structure.TensorStructure(dtypes.string, []),
          structure.TensorStructure(dtypes.string, []))
def _element_structure(self):
  return structure.NestedStructure(
      tuple([structure.TensorStructure(dtypes.string, [])] *
            self._num_outputs))
def _element_structure(self):
  return structure.TensorStructure(dtypes.int64, [])
def consume_optional(opt_tensor):
  value_structure = structure.TensorStructure(dtypes.float32, [])
  opt = optional_ops._OptionalImpl(opt_tensor, value_structure)
  return opt.get_value()
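# Context sketch (an assumption, since the surrounding test is not shown):
# `consume_optional` rebuilds an optional from its variant tensor, so a
# caller would hand it the variant of an existing optional, e.g.:
#
#   opt = optional_ops.Optional.from_value(constant_op.constant(1.0))
#   result = consume_optional(opt._variant_tensor)  # evaluates to 1.0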
class IteratorTest(test.TestCase, parameterized.TestCase):

  def testNoGradients(self):
    component = constant_op.constant([1.])
    side = constant_op.constant(0.)
    add = lambda x: x + side
    dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add)
    value = dataset.make_one_shot_iterator().get_next()
    self.assertIsNone(gradients_impl.gradients(value, component)[0])
    self.assertIsNone(gradients_impl.gradients(value, side)[0])
    self.assertIsNone(gradients_impl.gradients(value, [component, side])[0])

  def testCapturingStateInOneShotRaisesException(self):
    var = variables.Variable(37.0, name="myvar")
    dataset = (
        dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0])
        .map(lambda x: x + var))
    with self.assertRaisesRegexp(
        ValueError, r"`Dataset.make_one_shot_iterator\(\)` does not support "
        "datasets that capture stateful objects.+myvar"):
      dataset.make_one_shot_iterator()

  def testOneShotIterator(self):
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
        .repeat(14).make_one_shot_iterator())
    get_next = iterator.get_next()

    self.assertEqual([c.shape[1:] for c in components],
                     [t.shape for t in get_next])

    with self.cached_session() as sess:
      for _ in range(14):
        for i in range(7):
          result = sess.run(get_next)
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testOneShotIteratorCaptureByValue(self):
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    tensor_components = tuple([ops.convert_to_tensor(c) for c in components])

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(tensor_components)
        .map(_map_fn).repeat(14).make_one_shot_iterator())
    get_next = iterator.get_next()

    self.assertEqual([c.shape[1:] for c in components],
                     [t.shape for t in get_next])

    with self.cached_session() as sess:
      for _ in range(14):
        for i in range(7):
          result = sess.run(get_next)
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testOneShotIteratorInsideContainer(self):
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def within_container():

      def _map_fn(x, y, z):
        return math_ops.square(x), math_ops.square(y), math_ops.square(z)

      iterator = (
          dataset_ops.Dataset.from_tensor_slices(components)
          .map(_map_fn).repeat(14).make_one_shot_iterator())
      return iterator.get_next()

    server = server_lib.Server.create_local_server()

    # Create two iterators within unique containers, and run them to
    # make sure that the resources aren't shared.
    #
    # The test below would fail if cname were the same across both
    # sessions.
    for j in range(2):
      with session.Session(server.target) as sess:
        cname = "iteration%d" % j
        with ops.container(cname):
          get_next = within_container()

        for _ in range(14):
          for i in range(7):
            result = sess.run(get_next)
            for component, result_component in zip(components, result):
              self.assertAllEqual(component[i]**2, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  def testOneShotIteratorNonBlocking(self):
    dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    # Create a session with a single thread to ensure that the
    # one-shot iterator initializer does not deadlock.
    config = config_pb2.ConfigProto(
        inter_op_parallelism_threads=1, use_per_session_threads=True)
    with session.Session(config=config) as sess:
      self.assertAllEqual([1, 4, 9], sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

    # Test with multiple threads invoking the one-shot iterator concurrently.
    with session.Session(config=config) as sess:
      results = []

      def consumer_thread():
        try:
          results.append(sess.run(next_element))
        except errors.OutOfRangeError:
          results.append(None)

      num_threads = 8
      threads = [
          self.checkedThread(consumer_thread) for _ in range(num_threads)
      ]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

      self.assertEqual(num_threads, len(results))
      self.assertEqual(num_threads - 1,
                       len([None for r in results if r is None]))
      self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None])

  def testOneShotIteratorInitializerFails(self):
    # Define a dataset whose initialization will always fail.
    dataset = dataset_ops.Dataset.from_tensors(
        array_ops.check_numerics(
            constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    with self.cached_session() as sess:
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
        sess.run(next_element)

      # Test that subsequent attempts to use the iterator also fail.
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
        sess.run(next_element)

    with self.cached_session() as sess:

      def consumer_thread():
        with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
          sess.run(next_element)

      num_threads = 8
      threads = [
          self.checkedThread(consumer_thread) for _ in range(num_threads)
      ]
      for t in threads:
        t.start()
      for t in threads:
        t.join()

  def testSimpleSharedResource(self):
    components = (np.array(1, dtype=np.int64),
                  np.array([1, 2, 3], dtype=np.int64),
                  np.array(37.0, dtype=np.float64))

    server = server_lib.Server.create_local_server()

    # Create two non-overlapping sessions that share the same iterator
    # resource on the same server, and verify that an action of the
    # first session (initializing the iterator) is visible in the
    # second session.
    with ops.Graph().as_default():
      iterator = (
          dataset_ops.Dataset.from_tensors(components)
          .map(lambda x, y, z: (x, y, z)).make_initializable_iterator(
              shared_name="shared_iterator"))
      init_op = iterator.initializer
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        sess.run(init_op)
        results = sess.run(get_next)
        for component, result_component in zip(components, results):
          self.assertAllEqual(component, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

        # Re-initialize the iterator in the first session.
        sess.run(init_op)

    with ops.Graph().as_default():
      # Re-define the iterator manually, without defining any of the
      # functions in this graph, to ensure that we are not
      # accidentally redefining functions with the same names in the
      # new graph.
      iterator = iterator_ops.Iterator.from_structure(
          shared_name="shared_iterator",
          output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
          output_shapes=([], [3], []))
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        # Use the iterator without re-initializing in the second session.
        results = sess.run(get_next)
        for component, result_component in zip(components, results):
          self.assertAllEqual(component, result_component)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  def testNotInitializedError(self):
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    iterator = (
        dataset_ops.Dataset.from_tensors(components)
        .make_initializable_iterator())
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      with self.assertRaisesRegexp(errors.FailedPreconditionError,
                                   "iterator has not been initialized"):
        sess.run(get_next)

  def testReinitializableIterator(self):
    dataset_3 = dataset_ops.Dataset.from_tensors(
        constant_op.constant([1, 2, 3]))
    dataset_4 = dataset_ops.Dataset.from_tensors(
        constant_op.constant([4, 5, 6, 7]))
    iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types,
                                                    [None])

    dataset_3_init_op = iterator.make_initializer(dataset_3)
    dataset_4_init_op = iterator.make_initializer(dataset_4)
    get_next = iterator.get_next()

    self.assertEqual(dataset_3.output_types, iterator.output_types)
    self.assertEqual(dataset_4.output_types, iterator.output_types)
    self.assertEqual([None], iterator.output_shapes.as_list())

    with self.cached_session() as sess:
      # The iterator is initially uninitialized.
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(get_next)

      # Initialize with one dataset.
      sess.run(dataset_3_init_op)
      self.assertAllEqual([1, 2, 3], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Initialize with a different dataset.
      sess.run(dataset_4_init_op)
      self.assertAllEqual([4, 5, 6, 7], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

      # Reinitialize with the first dataset.
      sess.run(dataset_3_init_op)
      self.assertAllEqual([1, 2, 3], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testReinitializableIteratorStaticErrors(self):
    # Non-matching structure for types and shapes.
    with self.assertRaises(TypeError):
      iterator = iterator_ops.Iterator.from_structure(
          (dtypes.int64, dtypes.float64), [None])

    # Test validation of dataset argument.
    iterator = iterator_ops.Iterator.from_structure(
        (dtypes.int64, dtypes.float64))

    # Incompatible structure.
    with self.assertRaises(ValueError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(((constant_op.constant(
              [1, 2, 3], dtype=dtypes.int64),), (constant_op.constant(
                  [4., 5., 6., 7.], dtype=dtypes.float64),))))

    # Incompatible types.
    with self.assertRaises(TypeError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(
              (constant_op.constant([1, 2, 3], dtype=dtypes.int32),
               constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float32))))

    # Incompatible shapes.
    iterator = iterator_ops.Iterator.from_structure(
        (dtypes.int64, dtypes.float64), ([None], []))
    with self.assertRaises(TypeError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors(
              (constant_op.constant([1, 2, 3], dtype=dtypes.int64),
               constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64))))

  def testIteratorStringHandle(self):
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_3.make_one_shot_iterator()
    iterator_4 = dataset_4.make_one_shot_iterator()

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
    next_element = feedable_iterator.get_next()

    self.assertEqual(dataset_3.output_types, feedable_iterator.output_types)
    self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
    self.assertEqual([], feedable_iterator.output_shapes)

    with self.cached_session() as sess:
      iterator_3_handle = sess.run(iterator_3.string_handle())
      iterator_4_handle = sess.run(iterator_4.string_handle())

      self.assertEqual(
          10,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          1,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          20,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          2,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          30,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          3,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          40,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_3_handle})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle})

  def testIteratorStringHandleFuture(self):
    with forward_compat.forward_compatibility_horizon(2018, 8, 4):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_4 = dataset_4.make_one_shot_iterator()

      handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      feedable_iterator = iterator_ops.Iterator.from_string_handle(
          handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
      next_element = feedable_iterator.get_next()

      self.assertEqual(dataset_3.output_types, feedable_iterator.output_types)
      self.assertEqual(dataset_4.output_types, feedable_iterator.output_types)
      self.assertEqual([], feedable_iterator.output_shapes)

      with self.cached_session() as sess:
        iterator_3_handle = sess.run(iterator_3.string_handle())
        iterator_4_handle = sess.run(iterator_4.string_handle())

        self.assertEqual(
            10,
            sess.run(next_element,
                     feed_dict={handle_placeholder: iterator_4_handle}))
        self.assertEqual(
            1,
            sess.run(next_element,
                     feed_dict={handle_placeholder: iterator_3_handle}))
        self.assertEqual(
            20,
            sess.run(next_element,
                     feed_dict={handle_placeholder: iterator_4_handle}))
        self.assertEqual(
            2,
            sess.run(next_element,
                     feed_dict={handle_placeholder: iterator_3_handle}))
        self.assertEqual(
            30,
            sess.run(next_element,
                     feed_dict={handle_placeholder: iterator_4_handle}))
        self.assertEqual(
            3,
            sess.run(next_element,
                     feed_dict={handle_placeholder: iterator_3_handle}))
        self.assertEqual(
            40,
            sess.run(next_element,
                     feed_dict={handle_placeholder: iterator_4_handle}))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle})
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle})

  def testIteratorStringHandleReuseTensorObject(self):
    dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    one_shot_iterator = dataset.make_one_shot_iterator()
    initializable_iterator = dataset.make_initializable_iterator()
    structure_iterator = iterator_ops.Iterator.from_structure(
        dataset.output_types)

    created_ops = len(ops.get_default_graph().get_operations())

    self.assertIs(one_shot_iterator.string_handle(),
                  one_shot_iterator.string_handle())
    self.assertIs(initializable_iterator.string_handle(),
                  initializable_iterator.string_handle())
    self.assertIs(structure_iterator.string_handle(),
                  structure_iterator.string_handle())

    # Assert that getting the (default) string handle creates no ops.
    self.assertEqual(created_ops,
                     len(ops.get_default_graph().get_operations()))

    # Specifying an explicit name will create a new op.
    handle_with_name = one_shot_iterator.string_handle(name="foo")
    self.assertEqual("foo", handle_with_name.op.name)
    self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name)

    handle_with_same_name = one_shot_iterator.string_handle(name="foo")
    self.assertEqual("foo_1", handle_with_same_name.op.name)
    self.assertIsNot(handle_with_name, handle_with_same_name)

  def testIteratorStringHandleError(self):
    dataset_int_scalar = (
        dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat())
    dataset_float_vector = (
        dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]))

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])

    feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32, [])
    feedable_int_vector = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32, [None])
    feedable_int_any = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32)

    with self.cached_session() as sess:
      handle_int_scalar = sess.run(
          dataset_int_scalar.make_one_shot_iterator().string_handle())
      handle_float_vector = sess.run(
          dataset_float_vector.make_one_shot_iterator().string_handle())

      self.assertEqual(
          1,
          sess.run(feedable_int_scalar.get_next(),
                   feed_dict={handle_placeholder: handle_int_scalar}))
      self.assertEqual(
          2,
          sess.run(feedable_int_any.get_next(),
                   feed_dict={handle_placeholder: handle_int_scalar}))

      with self.assertRaises(errors.InvalidArgumentError):
        print(sess.run(feedable_int_vector.get_next(),
                       feed_dict={handle_placeholder: handle_int_scalar}))

      with self.assertRaises(errors.InvalidArgumentError):
        print(sess.run(feedable_int_vector.get_next(),
                       feed_dict={handle_placeholder: handle_float_vector}))

  def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 3

    with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_3_handle = iterator_3.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_3.output_types, dataset_3.output_shapes)
      return remote_iterator.get_next()

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with self.session(config=worker_config) as sess:
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [1])

      # Fails when target is cpu:2 where the resource is not located.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
            })

      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [2])

      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
      self.assertEqual(elem, [3])

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
            })

  def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
    s1 = server_lib.Server.create_local_server()
    s2 = server_lib.Server.create_local_server()
    s3 = server_lib.Server.create_local_server()

    cluster_def = cluster_pb2.ClusterDef()
    workers = cluster_def.job.add()
    workers.name = "worker"
    workers.tasks[0] = s1.target[len("grpc://"):]
    workers.tasks[1] = s2.target[len("grpc://"):]
    client = cluster_def.job.add()
    client.name = "client"
    client.tasks[0] = s3.target[len("grpc://"):]
    config = config_pb2.ConfigProto(cluster_def=cluster_def)

    worker_devices = [
        "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2)
    ]
    itr_handles = []
    for device in worker_devices:
      with ops.device(device):
        src = dataset_ops.Dataset.from_tensor_slices([device])
        itr = src.make_one_shot_iterator()
        itr_handles.append(itr.string_handle())

    targets = dataset_ops.Dataset.from_tensor_slices(worker_devices)
    handles = dataset_ops.Dataset.from_tensor_slices(itr_handles)

    @function.Defun(dtypes.string)
    def loading_func(h):
      remote_itr = iterator_ops.Iterator.from_string_handle(
          h, itr.output_types, itr.output_shapes)
      return remote_itr.get_next()

    def map_fn(target, handle):
      return functional_ops.remote_call(
          args=[handle], Tout=[dtypes.string], f=loading_func, target=target)

    with ops.device("/job:client"):
      client_dataset = dataset_ops.Dataset.zip((targets, handles)).map(map_fn)
      itr = client_dataset.make_initializable_iterator()
      n = itr.get_next()

    with session.Session(s3.target, config=config) as sess:
      sess.run(itr.initializer)
      expected_values = worker_devices
      for expected in expected_values:
        self.assertEqual((compat.as_bytes(expected),), sess.run(n))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)

  def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_3_handle = iterator_3.string_handle()

    def _encode_raw(byte_array):
      return bytes(bytearray(byte_array))

    @function.Defun(dtypes.uint8)
    def _remote_fn(h):
      handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, dataset_3.output_types, dataset_3.output_shapes)
      return remote_iterator.get_next()

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      iterator_3_handle_uint8 = parsing_ops.decode_raw(
          bytes=iterator_3_handle, out_type=dtypes.uint8)
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle_uint8],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with self.cached_session() as sess:
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
          })
      self.assertEqual(elem, [1])
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
          })
      self.assertEqual(elem, [2])
      elem = sess.run(
          remote_op,
          feed_dict={
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
          })
      self.assertEqual(elem, [3])
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(
            remote_op,
            feed_dict={
                target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
            })

  def testIncorrectIteratorRestore(self):

    def _path():
      return os.path.join(self.get_temp_dir(), "iterator")

    def _save_op(iterator_resource):
      iterator_state_variant = gen_dataset_ops.serialize_iterator(
          iterator_resource)
      save_op = io_ops.write_file(
          _path(), parsing_ops.serialize_tensor(iterator_state_variant))
      return save_op

    def _restore_op(iterator_resource):
      iterator_state_variant = parsing_ops.parse_tensor(
          io_ops.read_file(_path()), dtypes.variant)
      restore_op = gen_dataset_ops.deserialize_iterator(
          iterator_resource, iterator_state_variant)
      return restore_op

    def _build_range_dataset_graph():
      start = 1
      stop = 10
      iterator = dataset_ops.Dataset.range(
          start, stop).make_initializable_iterator()
      init_op = iterator.initializer
      get_next = iterator.get_next()
      save_op = _save_op(iterator._iterator_resource)
      restore_op = _restore_op(iterator._iterator_resource)
      return init_op, get_next, save_op, restore_op

    def _build_reader_dataset_graph():
      filenames = ["test"]  # Does not exist but we don't care in this test.
      iterator = readers.FixedLengthRecordDataset(
          filenames, 1, 0, 0).make_initializable_iterator()
      init_op = iterator.initializer
      get_next_op = iterator.get_next()
      save_op = _save_op(iterator._iterator_resource)
      restore_op = _restore_op(iterator._iterator_resource)
      return init_op, get_next_op, save_op, restore_op

    # Saving iterator for RangeDataset graph.
    with ops.Graph().as_default() as g:
      init_op, _, save_op, _ = _build_range_dataset_graph()
      with self.session(graph=g) as sess:
        sess.run(init_op)
        sess.run(save_op)

    # Attempt to restore the saved iterator into an IteratorResource of
    # incompatible type. An iterator of RangeDataset has output type int64,
    # while an iterator of FixedLengthRecordDataset has output type string.
    # So an InvalidArgumentError should be raised by
    # IteratorResource::set_iterator.
    with ops.Graph().as_default() as g:
      _, _, _, restore_op = _build_reader_dataset_graph()
      with self.session(graph=g) as sess:
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(restore_op)

  def testRepeatedGetNextWarning(self):
    iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator()
    warnings.simplefilter("always")
    with warnings.catch_warnings(record=True) as w:
      for _ in range(100):
        iterator.get_next()
    self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD,
                     len(w))
    for warning in w:
      self.assertIn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE,
                    str(warning.message))

  def testEagerIteratorAsync(self):
    with context.eager_mode(), context.execution_mode(context.ASYNC):
      val = 0
      dataset = dataset_ops.Dataset.range(10)
      for foo in dataset:
        self.assertEqual(val, foo.numpy())
        val += 1

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, []), ops.Tensor,
       dtypes.float32, []),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1]),
       sparse_tensor.SparseTensor, dtypes.int32, [1]),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
      }, structure.NestedStructure({
          "a": structure.TensorStructure(dtypes.float32, []),
          "b": (structure.TensorStructure(dtypes.string, [1]),
                structure.TensorStructure(dtypes.string, []))
      }), {
          "a": ops.Tensor,
          "b": (ops.Tensor, ops.Tensor)
      }, {
          "a": dtypes.float32,
          "b": (dtypes.string, dtypes.string)
      }, {
          "a": [],
          "b": ([1], [])
      }),
  )
  def testIteratorStructure(self, tf_value_fn, expected_element_structure,
                            expected_output_classes, expected_output_types,
                            expected_output_shapes):
    tf_value = tf_value_fn()
    iterator = dataset_ops.Dataset.from_tensors(
        tf_value).make_one_shot_iterator()

    self.assertTrue(
        expected_element_structure.is_compatible_with(
            iterator._element_structure))
    self.assertTrue(
        iterator._element_structure.is_compatible_with(
            expected_element_structure))

    self.assertEqual(expected_output_classes, iterator.output_classes)
    self.assertEqual(expected_output_types, iterator.output_types)
    self.assertEqual(expected_output_shapes, iterator.output_shapes)
def _make_init_func(self, init_func):
  """Make wrapping defun for init_func."""
  self._init_func = dataset_ops.StructuredFunctionWrapper(
      init_func,
      self._transformation_name(),
      input_structure=structure.TensorStructure(dtypes.int64, []))
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase): # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they # will be executed before the (eager- or graph-mode) test environment has been # set up. # pylint: disable=g-long-lambda,protected-access @parameterized.parameters( (lambda: constant_op.constant(37.0), structure.TensorStructure, [dtypes.float32], [[]]), (lambda: tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3, ), size=0), structure.TensorArrayStructure, [dtypes.variant], [None, 3]), (lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), structure.SparseTensorStructure, [dtypes.variant], [None]), (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])), structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]), (lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3] ]), (lambda: { "a": constant_op.constant(37.0), "b": (sparse_tensor. SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) }, structure.NestedStructure, [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None])) def testFlatStructure(self, value_fn, expected_structure, expected_types, expected_shapes): value = value_fn() s = structure.Structure.from_value(value) self.assertIsInstance(s, expected_structure) self.assertEqual(expected_types, s._flat_types) for expected, actual in zip(expected_shapes, s._flat_shapes): self.assertTrue(actual.is_compatible_with(expected)) self.assertTrue( tensor_shape.as_shape(expected).is_compatible_with(actual)) @parameterized.parameters( (lambda: constant_op.constant(37.0), lambda: [ constant_op.constant(38.0), array_ops.placeholder(dtypes.float32), variables.Variable(100.0), 42.0, np.array(42.0, dtype=np.float32) ], lambda: [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]), (lambda: tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3, ), size=0), lambda: [ tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3, ), size=0), tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3, ), size=10) ], lambda: [ tensor_array_ops.TensorArray( dtype=dtypes.int32, element_shape=(3, ), size=0), tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(), size=0) ]), (lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [ sparse_tensor.SparseTensor(indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]), sparse_tensor.SparseTensorValue(indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]), array_ops.sparse_placeholder(dtype=dtypes.int32), array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None]) ], lambda: [ constant_op.constant(37, shape=[4, 5]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[5, 6]), array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None, None]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5]) ]), (lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) }, lambda: [{ "a": constant_op.constant(15.0), "b": constant_op.constant([4, 5, 6]) }], lambda: [{ "a": constant_op.constant(15.0), "b": constant_op.constant([4, 5, 6, 7]) }, { "a": constant_op.constant(15), "b": constant_op.constant([4, 5, 6]) }, { "a": constant_op.constant(15), "b": 
sparse_tensor.SparseTensor( indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3]) }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]), ) @test_util.run_deprecated_v1 def testIsCompatibleWithStructure(self, original_value_fn, compatible_values_fn, incompatible_values_fn): original_value = original_value_fn() compatible_values = compatible_values_fn() incompatible_values = incompatible_values_fn() s = structure.Structure.from_value(original_value) for compatible_value in compatible_values: self.assertTrue( s.is_compatible_with( structure.Structure.from_value(compatible_value))) for incompatible_value in incompatible_values: self.assertFalse( s.is_compatible_with( structure.Structure.from_value(incompatible_value))) @parameterized.parameters( (lambda: constant_op.constant(37.0), ), (lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), ), (lambda: tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)), (lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) }, ), (lambda: { "a": constant_op.constant(37.0), "b": (sparse_tensor. SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) }, ), ) def testRoundTripConversion(self, value_fn): value = value_fn() s = structure.Structure.from_value(value) def maybe_stack_ta(v): if isinstance(v, tensor_array_ops.TensorArray): return v.stack() else: return v before = self.evaluate(maybe_stack_ta(value)) after = self.evaluate( maybe_stack_ta(s._from_tensor_list(s._to_tensor_list(value)))) flat_before = nest.flatten(before) flat_after = nest.flatten(after) for b, a in zip(flat_before, flat_after): if isinstance(b, sparse_tensor.SparseTensorValue): self.assertAllEqual(b.indices, a.indices) self.assertAllEqual(b.values, a.values) self.assertAllEqual(b.dense_shape, a.dense_shape) else: self.assertAllEqual(b, a) # pylint: enable=g-long-lambda def testIncompatibleStructure(self): # Define three mutually incompatible values/structures, and assert that: # 1. Using one structure to flatten a value with an incompatible structure # fails. # 2. Using one structure to restructre a flattened value with an # incompatible structure fails. 
value_tensor = constant_op.constant(42.0) s_tensor = structure.Structure.from_value(value_tensor) flat_tensor = s_tensor._to_tensor_list(value_tensor) value_sparse_tensor = sparse_tensor.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]) s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor) flat_sparse_tensor = s_sparse_tensor._to_tensor_list( value_sparse_tensor) value_nest = { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) } s_nest = structure.Structure.from_value(value_nest) flat_nest = s_nest._to_tensor_list(value_nest) with self.assertRaisesRegexp( ValueError, r"SparseTensor.* is not convertible to a tensor with " r"dtype.*float32.* and shape \(\)"): s_tensor._to_tensor_list(value_sparse_tensor) with self.assertRaisesRegexp( ValueError, r"Value \{.*\} is not convertible to a tensor with " r"dtype.*float32.* and shape \(\)"): s_tensor._to_tensor_list(value_nest) with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"): s_sparse_tensor._to_tensor_list(value_tensor) with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"): s_sparse_tensor._to_tensor_list(value_nest) with self.assertRaisesRegexp( ValueError, "Tensor.* not compatible with the nested structure " ".*TensorStructure.*TensorStructure"): s_nest._to_tensor_list(value_tensor) with self.assertRaisesRegexp( ValueError, "SparseTensor.* not compatible with the nested structure " ".*TensorStructure.*TensorStructure"): s_nest._to_tensor_list(value_sparse_tensor) with self.assertRaisesRegexp( ValueError, r"Cannot convert.*with dtype.*float32.* and shape \(\)"): s_tensor._from_tensor_list(flat_sparse_tensor) with self.assertRaisesRegexp( ValueError, "TensorStructure corresponds to a single tf.Tensor."): s_tensor._from_tensor_list(flat_nest) with self.assertRaisesRegexp( ValueError, "SparseTensorStructure corresponds to a single tf.variant " "vector of length 3."): s_sparse_tensor._from_tensor_list(flat_tensor) with self.assertRaisesRegexp( ValueError, "SparseTensorStructure corresponds to a single tf.variant " "vector of length 3."): s_sparse_tensor._from_tensor_list(flat_nest) with self.assertRaisesRegexp( ValueError, "Expected 2 flat values in NestedStructure but got 1."): s_nest._from_tensor_list(flat_tensor) with self.assertRaisesRegexp( ValueError, "Expected 2 flat values in NestedStructure but got 1."): s_nest._from_tensor_list(flat_sparse_tensor) def testIncompatibleNestedStructure(self): # Define three mutually incompatible nested values/structures, and assert # that: # 1. Using one structure to flatten a value with an incompatible structure # fails. # 2. Using one structure to restructure a flattened value with an # incompatible structure fails. value_0 = { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) } s_0 = structure.Structure.from_value(value_0) flat_s_0 = s_0._to_tensor_list(value_0) # `value_1` has compatible nested structure with `value_0`, but different # classes. value_1 = { "a": constant_op.constant(37.0), "b": sparse_tensor.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]) } s_1 = structure.Structure.from_value(value_1) flat_s_1 = s_1._to_tensor_list(value_1) # `value_2` has incompatible nested structure with `value_0` and `value_1`. 
value_2 = { "a": constant_op.constant(37.0), "b": (sparse_tensor.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]), sparse_tensor.SparseTensor(indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) } s_2 = structure.Structure.from_value(value_2) flat_s_2 = s_2._to_tensor_list(value_2) with self.assertRaisesRegexp( ValueError, "SparseTensor.* not compatible with the nested structure " ".*TensorStructure"): s_0._to_tensor_list(value_1) with self.assertRaisesRegexp( ValueError, "SparseTensor.*SparseTensor.* not compatible with the " "nested structure .*TensorStructure"): s_0._to_tensor_list(value_2) with self.assertRaisesRegexp( ValueError, "Tensor.* not compatible with the nested structure " ".*SparseTensorStructure"): s_1._to_tensor_list(value_0) with self.assertRaisesRegexp( ValueError, "SparseTensor.*SparseTensor.* not compatible with the " "nested structure .*TensorStructure"): s_1._to_tensor_list(value_2) # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp # needs to account for "a" coming before or after "b". It might be worth # adding a deterministic repr for these error messages (among other # improvements). with self.assertRaisesRegexp( ValueError, "Tensor.*Tensor.* not compatible with the nested structure " ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|" "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)" ): s_2._to_tensor_list(value_0) with self.assertRaisesRegexp( ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* " "not compatible with the nested structure .*" "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|" "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)" ): s_2._to_tensor_list(value_1) with self.assertRaisesRegexp( ValueError, r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"): s_0._from_tensor_list(flat_s_1) with self.assertRaisesRegexp( ValueError, "Expected 2 flat values in NestedStructure but got 3."): s_0._from_tensor_list(flat_s_2) with self.assertRaisesRegexp( ValueError, "SparseTensorStructure corresponds to a single tf.variant " "vector of length 3."): s_1._from_tensor_list(flat_s_0) with self.assertRaisesRegexp( ValueError, "Expected 2 flat values in NestedStructure but got 3."): s_1._from_tensor_list(flat_s_2) with self.assertRaisesRegexp( ValueError, "Expected 3 flat values in NestedStructure but got 2."): s_2._from_tensor_list(flat_s_0) with self.assertRaisesRegexp( ValueError, "Expected 3 flat values in NestedStructure but got 2."): s_2._from_tensor_list(flat_s_1) @parameterized.named_parameters( ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor, structure.TensorStructure(dtypes.float32, [])), ("SparseTensor", dtypes.int32, tensor_shape.matrix( 2, 2), sparse_tensor.SparseTensor, structure.SparseTensorStructure(dtypes.int32, [2, 2])), ("TensorArray0", dtypes.int32, tensor_shape.as_shape( [None, True, 2, 2]), tensor_array_ops.TensorArray, structure.TensorArrayStructure( dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)), ("TensorArray1", dtypes.int32, tensor_shape.as_shape( [True, None, 2, 2]), tensor_array_ops.TensorArray, structure.TensorArrayStructure( dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)), ("TensorArray2", dtypes.int32, tensor_shape.as_shape([True, False, 2, 2 ]), tensor_array_ops.TensorArray, structure.TensorArrayStructure( dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)), ("Nest", { "a": dtypes.float32, "b": (dtypes.int32, dtypes.string) }, { "a": tensor_shape.scalar(), "b": (tensor_shape.matrix(2, 
2), tensor_shape.scalar()) }, { "a": ops.Tensor, "b": (sparse_tensor.SparseTensor, ops.Tensor) }, structure.NestedStructure({ "a": structure.TensorStructure(dtypes.float32, []), "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]), structure.TensorStructure(dtypes.string, [])) })), ) def testConvertLegacyStructure(self, output_types, output_shapes, output_classes, expected_structure): actual_structure = structure.convert_legacy_structure( output_types, output_shapes, output_classes) self.assertTrue( expected_structure.is_compatible_with(actual_structure)) self.assertTrue( actual_structure.is_compatible_with(expected_structure)) def testNestedNestedStructure(self): # Although `Structure.from_value()` will not construct one, a nested # structure containing nested `NestedStructure` objects can occur if a # structure is constructed manually. s = structure.NestedStructure( (structure.TensorStructure(dtypes.int64, []), structure.NestedStructure( (structure.TensorStructure(dtypes.float32, []), structure.TensorStructure(dtypes.string, []))))) int64_t = constant_op.constant(37, dtype=dtypes.int64) float32_t = constant_op.constant(42.0) string_t = constant_op.constant("Foo") nested_tensors = (int64_t, (float32_t, string_t)) tensor_list = s._to_tensor_list(nested_tensors) for expected, actual in zip([int64_t, float32_t, string_t], tensor_list): self.assertIs(expected, actual) (actual_int64_t, (actual_float32_t, actual_string_t)) = s._from_tensor_list(tensor_list) self.assertIs(int64_t, actual_int64_t) self.assertIs(float32_t, actual_float32_t) self.assertIs(string_t, actual_string_t) (actual_int64_t, (actual_float32_t, actual_string_t)) = (s._from_compatible_tensor_list(tensor_list)) self.assertIs(int64_t, actual_int64_t) self.assertIs(float32_t, actual_float32_t) self.assertIs(string_t, actual_string_t) @parameterized.named_parameters( ("Tensor", structure.TensorStructure(dtypes.float32, []), 32, structure.TensorStructure(dtypes.float32, [32])), ("TensorUnknown", structure.TensorStructure(dtypes.float32, []), None, structure.TensorStructure(dtypes.float32, [None])), ("SparseTensor", structure.SparseTensorStructure( dtypes.float32, [None]), 32, structure.SparseTensorStructure(dtypes.float32, [32, None])), ("SparseTensorUnknown", structure.SparseTensorStructure(dtypes.float32, [4]), None, structure.SparseTensorStructure(dtypes.float32, [None, 4])), ("Nest", structure.NestedStructure({ "a": structure.TensorStructure(dtypes.float32, []), "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]), structure.TensorStructure(dtypes.string, [])) }), 128, structure.NestedStructure({ "a": structure.TensorStructure(dtypes.float32, [128]), "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]), structure.TensorStructure(dtypes.string, [128])) })), ) def testBatch(self, element_structure, batch_size, expected_batched_structure): batched_structure = element_structure._batch(batch_size) self.assertTrue( batched_structure.is_compatible_with(expected_batched_structure)) self.assertTrue( expected_batched_structure.is_compatible_with(batched_structure)) @parameterized.named_parameters( ("Tensor", structure.TensorStructure(dtypes.float32, [32]), structure.TensorStructure(dtypes.float32, [])), ("TensorUnknown", structure.TensorStructure(dtypes.float32, [None]), structure.TensorStructure(dtypes.float32, [])), ("SparseTensor", structure.SparseTensorStructure(dtypes.float32, [32, None]), structure.SparseTensorStructure(dtypes.float32, [None])), ("SparseTensorUnknown", 
structure.SparseTensorStructure(dtypes.float32, [None, 4]), structure.SparseTensorStructure(dtypes.float32, [4])), ("Nest", structure.NestedStructure({ "a": structure.TensorStructure(dtypes.float32, [128]), "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]), structure.TensorStructure(dtypes.string, [None])) }), structure.NestedStructure({ "a": structure.TensorStructure(dtypes.float32, []), "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]), structure.TensorStructure(dtypes.string, [])) })), ) def testUnbatch(self, element_structure, expected_unbatched_structure): unbatched_structure = element_structure._unbatch() self.assertTrue( unbatched_structure.is_compatible_with( expected_unbatched_structure)) self.assertTrue( expected_unbatched_structure.is_compatible_with( unbatched_structure)) # pylint: disable=g-long-lambda @parameterized.named_parameters( ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]), lambda: constant_op.constant([1.0, 2.0])), ("SparseTensor", lambda: sparse_tensor.SparseTensor( indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]), lambda: sparse_tensor.SparseTensor( indices=[[0]], values=[13], dense_shape=[2])), ("Nest", lambda: (constant_op.constant([[1.0, 2.0], [3.0, 4.0]]), sparse_tensor.SparseTensor( indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])), lambda: (constant_op.constant([1.0, 2.0]), sparse_tensor.SparseTensor( indices=[[0]], values=[13], dense_shape=[2]))), ) def testToBatchedTensorList(self, value_fn, element_0_fn): batched_value = value_fn() s = structure.Structure.from_value(batched_value) batched_tensor_list = s._to_batched_tensor_list(batched_value) # The batch dimension is 2 for all of the test cases. # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT # tensors in which we store sparse tensors. for t in batched_tensor_list: if t.dtype != dtypes.variant: self.assertEqual(2, self.evaluate(array_ops.shape(t)[0])) # Test that the 0th element from the unbatched tensor is equal to the # expected value. expected_element_0 = self.evaluate(element_0_fn()) unbatched_s = s._unbatch() actual_element_0 = unbatched_s._from_tensor_list( [t[0] for t in batched_tensor_list]) for expected, actual in zip(nest.flatten(expected_element_0), nest.flatten(actual_element_0)): if sparse_tensor.is_sparse(expected): self.assertSparseValuesEqual(expected, actual) else: self.assertAllEqual(expected, actual)
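# Illustrative sketch, not part of the original test file: the round-trip
# pattern that the StructureTest cases above exercise, gathered in one place.
# It uses only APIs already exercised by the tests (the
# `Structure.from_value` factory and the private `_to_tensor_list`,
# `_from_tensor_list`, `_batch` and `_unbatch` methods), assumes this file's
# imports, and the helper name is ours.
def _example_structure_round_trip():
  """Flattens a nested value to a tensor list and rebuilds it (sketch)."""
  value = {
      "a": constant_op.constant(37.0),
      "b": sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])
  }
  s = structure.Structure.from_value(value)  # inferred NestedStructure
  flat = s._to_tensor_list(value)            # one dense + one variant tensor
  restored = s._from_tensor_list(flat)       # same nesting as `value`
  batched_s = s._batch(32)                   # structure of a batch of 32
  element_s = batched_s._unbatch()           # back to the per-element form
  del element_s  # the batch/unbatch pair is shown only for illustration
  return restored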
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase): def testFromValue(self): opt = optional_ops.Optional.from_value(constant_op.constant(37.0)) self.assertTrue(self.evaluate(opt.has_value())) self.assertEqual(37.0, self.evaluate(opt.get_value())) def testFromStructuredValue(self): opt = optional_ops.Optional.from_value({ "a": constant_op.constant(37.0), "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")) }) self.assertTrue(self.evaluate(opt.has_value())) self.assertEqual({ "a": 37.0, "b": ([b"Foo"], b"Bar") }, self.evaluate(opt.get_value())) def testFromSparseTensor(self): st_0 = sparse_tensor.SparseTensorValue(indices=np.array([[0]]), values=np.array([0], dtype=np.int64), dense_shape=np.array([1])) st_1 = sparse_tensor.SparseTensorValue( indices=np.array([[0, 0], [1, 1]]), values=np.array([-1., 1.], dtype=np.float32), dense_shape=np.array([2, 2])) opt = optional_ops.Optional.from_value((st_0, st_1)) self.assertTrue(self.evaluate(opt.has_value())) val_0, val_1 = opt.get_value() for expected, actual in [(st_0, val_0), (st_1, val_1)]: self.assertAllEqual(expected.indices, self.evaluate(actual.indices)) self.assertAllEqual(expected.values, self.evaluate(actual.values)) self.assertAllEqual(expected.dense_shape, self.evaluate(actual.dense_shape)) def testFromNone(self): value_structure = structure.TensorStructure(dtypes.float32, []) opt = optional_ops.Optional.none_from_structure(value_structure) self.assertTrue( opt.value_structure.is_compatible_with(value_structure)) self.assertFalse( opt.value_structure.is_compatible_with( structure.TensorStructure(dtypes.float32, [1]))) self.assertFalse( opt.value_structure.is_compatible_with( structure.TensorStructure(dtypes.int32, []))) self.assertFalse(self.evaluate(opt.has_value())) with self.assertRaises(errors.InvalidArgumentError): self.evaluate(opt.get_value()) def testAddN(self): devices = ["/cpu:0"] if test_util.is_gpu_available(): devices.append("/gpu:0") for device in devices: with ops.device(device): # With value opt1 = optional_ops.Optional.from_value((1.0, 2.0)) opt2 = optional_ops.Optional.from_value((3.0, 4.0)) add_tensor = math_ops.add_n( [opt1._variant_tensor, opt2._variant_tensor]) add_opt = optional_ops._OptionalImpl(add_tensor, opt1.value_structure) self.assertAllEqual(self.evaluate(add_opt.get_value()), (4.0, 6.0)) # Without value opt_none1 = optional_ops.Optional.none_from_structure( opt1.value_structure) opt_none2 = optional_ops.Optional.none_from_structure( opt2.value_structure) add_tensor = math_ops.add_n( [opt_none1._variant_tensor, opt_none2._variant_tensor]) add_opt = optional_ops._OptionalImpl(add_tensor, opt_none1.value_structure) self.assertFalse(self.evaluate(add_opt.has_value())) def testNestedAddN(self): devices = ["/cpu:0"] if test_util.is_gpu_available(): devices.append("/gpu:0") for device in devices: with ops.device(device): opt1 = optional_ops.Optional.from_value([1, 2.0]) opt2 = optional_ops.Optional.from_value([3, 4.0]) opt3 = optional_ops.Optional.from_value( (5.0, opt1._variant_tensor)) opt4 = optional_ops.Optional.from_value( (6.0, opt2._variant_tensor)) add_tensor = math_ops.add_n( [opt3._variant_tensor, opt4._variant_tensor]) add_opt = optional_ops._OptionalImpl(add_tensor, opt3.value_structure) self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0) inner_add_opt = optional_ops._OptionalImpl( add_opt.get_value()[1], opt1.value_structure) self.assertAllEqual(inner_add_opt.get_value(), [4, 6.0]) def testZerosLike(self): devices = ["/cpu:0"] if test_util.is_gpu_available(): 
devices.append("/gpu:0") for device in devices: with ops.device(device): # With value opt = optional_ops.Optional.from_value((1.0, 2.0)) zeros_tensor = array_ops.zeros_like(opt._variant_tensor) zeros_opt = optional_ops._OptionalImpl(zeros_tensor, opt.value_structure) self.assertAllEqual(self.evaluate(zeros_opt.get_value()), (0.0, 0.0)) # Without value opt_none = optional_ops.Optional.none_from_structure( opt.value_structure) zeros_tensor = array_ops.zeros_like(opt_none._variant_tensor) zeros_opt = optional_ops._OptionalImpl( zeros_tensor, opt_none.value_structure) self.assertFalse(self.evaluate(zeros_opt.has_value())) def testNestedZerosLike(self): devices = ["/cpu:0"] if test_util.is_gpu_available(): devices.append("/gpu:0") for device in devices: with ops.device(device): opt1 = optional_ops.Optional.from_value(1.0) opt2 = optional_ops.Optional.from_value(opt1._variant_tensor) zeros_tensor = array_ops.zeros_like(opt2._variant_tensor) zeros_opt = optional_ops._OptionalImpl(zeros_tensor, opt2.value_structure) inner_zeros_opt = optional_ops._OptionalImpl( zeros_opt.get_value(), opt1.value_structure) self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0) def testCopyToGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") with ops.device("/cpu:0"): optional_with_value = optional_ops.Optional.from_value( (constant_op.constant(37.0), constant_op.constant("Foo"), constant_op.constant(42))) optional_none = optional_ops.Optional.none_from_structure( structure.TensorStructure(dtypes.float32, [])) with ops.device("/gpu:0"): gpu_optional_with_value = optional_ops._OptionalImpl( array_ops.identity(optional_with_value._variant_tensor), optional_with_value.value_structure) gpu_optional_none = optional_ops._OptionalImpl( array_ops.identity(optional_none._variant_tensor), optional_none.value_structure) gpu_optional_with_value_has_value = gpu_optional_with_value.has_value( ) gpu_optional_with_value_values = gpu_optional_with_value.get_value( ) gpu_optional_none_has_value = gpu_optional_none.has_value() self.assertTrue(self.evaluate(gpu_optional_with_value_has_value)) self.assertEqual((37.0, b"Foo", 42), self.evaluate(gpu_optional_with_value_values)) self.assertFalse(self.evaluate(gpu_optional_none_has_value)) def testNestedCopyToGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPU available") with ops.device("/cpu:0"): optional_with_value = optional_ops.Optional.from_value( (constant_op.constant(37.0), constant_op.constant("Foo"), constant_op.constant(42))) optional_none = optional_ops.Optional.none_from_structure( structure.TensorStructure(dtypes.float32, [])) nested_optional = optional_ops.Optional.from_value( (optional_with_value._variant_tensor, optional_none._variant_tensor, 1.0)) with ops.device("/gpu:0"): gpu_nested_optional = optional_ops._OptionalImpl( array_ops.identity(nested_optional._variant_tensor), nested_optional.value_structure) gpu_nested_optional_has_value = gpu_nested_optional.has_value() gpu_nested_optional_values = gpu_nested_optional.get_value() self.assertTrue(self.evaluate(gpu_nested_optional_has_value)) inner_with_value = optional_ops._OptionalImpl( gpu_nested_optional_values[0], optional_with_value.value_structure) inner_none = optional_ops._OptionalImpl(gpu_nested_optional_values[1], optional_none.value_structure) self.assertEqual((37.0, b"Foo", 42), self.evaluate(inner_with_value.get_value())) self.assertFalse(self.evaluate(inner_none.has_value())) self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2])) def 
_assertElementValueEqual(self, expected, actual): if isinstance(expected, dict): self.assertItemsEqual(list(expected.keys()), list(actual.keys())) for k in expected.keys(): self._assertElementValueEqual(expected[k], actual[k]) elif isinstance(expected, sparse_tensor.SparseTensorValue): self.assertAllEqual(expected.indices, actual.indices) self.assertAllEqual(expected.values, actual.values) self.assertAllEqual(expected.dense_shape, actual.dense_shape) else: self.assertAllEqual(expected, actual) # pylint: disable=g-long-lambda @parameterized.named_parameters( ("Tensor", lambda: constant_op.constant(37.0), structure.TensorStructure(dtypes.float32, [])), ("SparseTensor", lambda: sparse_tensor.SparseTensor( indices=[[0, 1]], values=constant_op.constant([0], dtype=dtypes.int32), dense_shape=[10, 10]), structure.SparseTensorStructure(dtypes.int32, [10, 10])), ("Nest", lambda: { "a": constant_op.constant(37.0), "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")) }, { "a": structure.TensorStructure(dtypes.float32, []), "b": (structure.TensorStructure(dtypes.string, [1]), structure.TensorStructure(dtypes.string, [])) }), ("Optional", lambda: optional_ops.Optional.from_value(37.0), optional_ops.OptionalStructure( structure.TensorStructure(dtypes.float32, []))), ) def testOptionalStructure(self, tf_value_fn, expected_value_structure): tf_value = tf_value_fn() opt = optional_ops.Optional.from_value(tf_value) self.assertTrue( structure.are_compatible(opt.value_structure, expected_value_structure)) opt_structure = structure.type_spec_from_value(opt) self.assertIsInstance(opt_structure, optional_ops.OptionalStructure) self.assertTrue(structure.are_compatible(opt_structure, opt_structure)) self.assertTrue( structure.are_compatible(opt_structure._value_structure, expected_value_structure)) self.assertEqual([dtypes.variant], opt_structure._flat_types) self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes) # All OptionalStructure objects are not compatible with a non-optional # value. non_optional_structure = structure.type_spec_from_value( constant_op.constant(42.0)) self.assertFalse( opt_structure.is_compatible_with(non_optional_structure)) # Assert that the optional survives a round-trip via _from_tensor_list() # and _to_tensor_list(). 
round_trip_opt = opt_structure._from_tensor_list( opt_structure._to_tensor_list(opt)) if isinstance(tf_value, optional_ops.Optional): self._assertElementValueEqual( self.evaluate(tf_value.get_value()), self.evaluate(round_trip_opt.get_value().get_value())) else: self._assertElementValueEqual( self.evaluate(tf_value), self.evaluate(round_trip_opt.get_value())) @parameterized.named_parameters( ("Tensor", np.array([1, 2, 3], dtype=np.int32), lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True), ("SparseTensor", sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]], values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]), lambda: sparse_tensor.SparseTensor(indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]), False), ("Nest", { "a": np.array([1, 2, 3], dtype=np.int32), "b": sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]], values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]) }, lambda: { "a": constant_op.constant([4, 5, 6], dtype=dtypes.int32), "b": sparse_tensor.SparseTensor(indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]) }, False), ) def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, works_on_gpu): if not works_on_gpu and test.is_gpu_available(): self.skipTest("Test case not yet supported on GPU.") ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3) if context.executing_eagerly(): iterator = dataset_ops.make_one_shot_iterator(ds) # For each element of the dataset, assert that the optional evaluates to # the expected value. for _ in range(3): next_elem = iterator_ops.get_next_as_optional(iterator) self.assertIsInstance(next_elem, optional_ops.Optional) self.assertTrue( structure.are_compatible( next_elem.value_structure, structure.type_spec_from_value(tf_value_fn()))) self.assertTrue(next_elem.has_value()) self._assertElementValueEqual(np_value, next_elem.get_value()) # After exhausting the iterator, `next_elem.has_value()` will evaluate to # false, and attempting to get the value will fail. for _ in range(2): next_elem = iterator_ops.get_next_as_optional(iterator) self.assertFalse(self.evaluate(next_elem.has_value())) with self.assertRaises(errors.InvalidArgumentError): self.evaluate(next_elem.get_value()) else: iterator = dataset_ops.make_initializable_iterator(ds) next_elem = iterator_ops.get_next_as_optional(iterator) self.assertIsInstance(next_elem, optional_ops.Optional) self.assertTrue( structure.are_compatible( next_elem.value_structure, structure.type_spec_from_value(tf_value_fn()))) # Before initializing the iterator, evaluating the optional fails with # a FailedPreconditionError. This is only relevant in graph mode. elem_has_value_t = next_elem.has_value() elem_value_t = next_elem.get_value() with self.assertRaises(errors.FailedPreconditionError): self.evaluate(elem_has_value_t) with self.assertRaises(errors.FailedPreconditionError): self.evaluate(elem_value_t) # Now we initialize the iterator. self.evaluate(iterator.initializer) # For each element of the dataset, assert that the optional evaluates to # the expected value. for _ in range(3): elem_has_value, elem_value = self.evaluate( [elem_has_value_t, elem_value_t]) self.assertTrue(elem_has_value) self._assertElementValueEqual(np_value, elem_value) # After exhausting the iterator, `next_elem.has_value()` will evaluate to # false, and attempting to get the value will fail. 
for _ in range(2): self.assertFalse(self.evaluate(elem_has_value_t)) with self.assertRaises(errors.InvalidArgumentError): self.evaluate(elem_value_t) def testFunctionBoundaries(self): @def_function.function def get_optional(): x = constant_op.constant(1.0) opt = optional_ops.Optional.from_value(x) # TODO(skyewm): support returning Optionals from functions? return opt._variant_tensor # TODO(skyewm): support Optional arguments? @def_function.function def consume_optional(opt_tensor): value_structure = structure.TensorStructure(dtypes.float32, []) opt = optional_ops._OptionalImpl(opt_tensor, value_structure) return opt.get_value() opt_tensor = get_optional() val = consume_optional(opt_tensor) self.assertEqual(self.evaluate(val), 1.0) def testLimitedRetracing(self): trace_count = [0] @def_function.function def f(opt): trace_count[0] += 1 return opt.get_value() opt1 = optional_ops.Optional.from_value(constant_op.constant(37.0)) opt2 = optional_ops.Optional.from_value(constant_op.constant(42.0)) for _ in range(10): self.assertEqual(self.evaluate(f(opt1)), 37.0) self.assertEqual(self.evaluate(f(opt2)), 42.0) self.assertEqual(trace_count[0], 1)
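# Illustrative sketch, not from the original file: the basic Optional
# protocol that OptionalTest covers, using only public calls exercised above
# (assumes this file's imports; the helper name is ours).
def _example_optional_protocol():
  """Shows the has_value()/get_value() contract of tf.data optionals."""
  opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
  empty = optional_ops.Optional.none_from_structure(opt.value_structure)
  # `has_value()` is a boolean tensor. Evaluating `get_value()` on an empty
  # optional fails at runtime with InvalidArgumentError, as the tests above
  # assert.
  return opt.has_value(), opt.get_value(), empty.has_value()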
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase, test_util.TensorFlowTestCase): # pylint: disable=g-long-lambda,protected-access @parameterized.named_parameters( ("Tensor", lambda: constant_op.constant(37.0), tensor_spec.TensorSpec, [dtypes.float32], [[]]), ("TensorArray", lambda: tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3, ), size=0), tensor_array_ops.TensorArraySpec, [dtypes.variant], [[]]), ("SparseTensor", lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), sparse_tensor.SparseTensorSpec, [dtypes.variant], [None]), ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [4]]), ragged_tensor.RaggedTensorSpec, [dtypes.variant], [None]), ("Nested_0", lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])), tuple, [dtypes.float32, dtypes.int32], [[], [3]]), ("Nested_1", lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) }, dict, [dtypes.float32, dtypes.int32], [[], [3]]), ("Nested_2", lambda: { "a": constant_op.constant(37.0), "b": (sparse_tensor. SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) }, dict, [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None]), ) def testFlatStructure(self, value_fn, expected_structure, expected_types, expected_shapes): value = value_fn() s = structure.type_spec_from_value(value) self.assertIsInstance(s, expected_structure) flat_types = structure.get_flat_tensor_types(s) self.assertEqual(expected_types, flat_types) flat_shapes = structure.get_flat_tensor_shapes(s) self.assertLen(flat_shapes, len(expected_shapes)) for expected, actual in zip(expected_shapes, flat_shapes): if expected is None: self.assertEqual(actual.ndims, None) else: self.assertEqual(actual.as_list(), expected) @parameterized.named_parameters( ("Tensor", lambda: constant_op.constant(37.0), lambda: [ constant_op.constant(38.0), array_ops.placeholder(dtypes.float32), variables.Variable(100.0), 42.0, np.array(42.0, dtype=np.float32) ], lambda: [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]), ("TensorArray", lambda: tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3, ), size=0), lambda: [ tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3, ), size=0), tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3, ), size=10) ], lambda: [ tensor_array_ops.TensorArray( dtype=dtypes.int32, element_shape=(3, ), size=0), tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(), size=0) ]), ("SparseTensor", lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [ sparse_tensor.SparseTensor(indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]), sparse_tensor.SparseTensorValue(indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]), array_ops.sparse_placeholder(dtype=dtypes.int32), array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None]) ], lambda: [ constant_op.constant(37, shape=[4, 5]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[5, 6]), array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None, None]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5]) ]), ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [3]]), lambda: [ ragged_factory_ops.constant([[1, 2], [3, 4], []]), ragged_factory_ops.constant([[1], [2, 3, 4], [5]]), ], lambda: [ 
ragged_factory_ops.constant(1), ragged_factory_ops.constant([1, 2]), ragged_factory_ops.constant([[1], [2]]), ragged_factory_ops.constant([["a", "b"]]), ]), ("Nested", lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) }, lambda: [{ "a": constant_op.constant(15.0), "b": constant_op.constant([4, 5, 6]) }], lambda: [{ "a": constant_op.constant(15.0), "b": constant_op.constant([4, 5, 6, 7]) }, { "a": constant_op.constant(15), "b": constant_op.constant([4, 5, 6]) }, { "a": constant_op.constant(15), "b": sparse_tensor.SparseTensor( indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3]) }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]), ) @test_util.run_deprecated_v1 def testIsCompatibleWithStructure(self, original_value_fn, compatible_values_fn, incompatible_values_fn): original_value = original_value_fn() compatible_values = compatible_values_fn() incompatible_values = incompatible_values_fn() s = structure.type_spec_from_value(original_value) for compatible_value in compatible_values: self.assertTrue( structure.are_compatible( s, structure.type_spec_from_value(compatible_value))) for incompatible_value in incompatible_values: self.assertFalse( structure.are_compatible( s, structure.type_spec_from_value(incompatible_value))) @parameterized.named_parameters( ("Tensor", lambda: constant_op.constant(37.0), lambda: constant_op.constant(42.0), lambda: constant_op.constant([5])), ("TensorArray", lambda: tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3, ), size=0), lambda: tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3, ), size=0), lambda: tensor_array_ops.TensorArray( dtype=dtypes.int32, element_shape=(), size=0)), ("SparseTensor", lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: sparse_tensor.SparseTensor( indices=[[1, 2]], values=[42], dense_shape=[4, 5]), lambda: sparse_tensor.SparseTensor( indices=[[3]], values=[-1], dense_shape=[5]), lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])), ("RaggedTensor", lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]), lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]), lambda: ragged_factory_ops.constant([[[1]], [[2], [3]]], ragged_rank=1), lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]), lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])), ("Nested", lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) }, lambda: { "a": constant_op.constant(42.0), "b": constant_op.constant([4, 5, 6]) }, lambda: { "a": constant_op.constant([1, 2, 3]), "b": constant_op.constant(37.0) }), ) # pyformat: disable def testStructureFromValueEquality(self, value1_fn, value2_fn, *not_equal_value_fns): # pylint: disable=g-generic-assert s1 = structure.type_spec_from_value(value1_fn()) s2 = structure.type_spec_from_value(value2_fn()) self.assertEqual(s1, s1) # check __eq__ operator. self.assertEqual(s1, s2) # check __eq__ operator. self.assertFalse(s1 != s1) # check __ne__ operator. self.assertFalse(s1 != s2) # check __ne__ operator. for c1, c2 in zip(nest.flatten(s1), nest.flatten(s2)): self.assertEqual(hash(c1), hash(c1)) self.assertEqual(hash(c1), hash(c2)) for value_fn in not_equal_value_fns: s3 = structure.type_spec_from_value(value_fn()) self.assertNotEqual(s1, s3) # check __ne__ operator. self.assertNotEqual(s2, s3) # check __ne__ operator. self.assertFalse(s1 == s3) # check __eq__ operator. self.assertFalse(s2 == s3) # check __eq__ operator. 
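  # Illustrative sketch, not in the original file: the TypeSpec equality and
  # hash contract that the test above and testHash below rely on. The method
  # name is ours and deliberately lacks the `test` prefix, so the runner
  # ignores it.
  def _example_spec_equality(self):
    s1 = structure.type_spec_from_value(constant_op.constant(37.0))
    s2 = structure.type_spec_from_value(constant_op.constant(42.0))
    # Specs are derived from dtype and shape only, never from tensor values,
    # so two float32 scalars yield equal specs with equal hashes.
    self.assertEqual(s1, s2)
    self.assertEqual(hash(s1), hash(s2))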
@parameterized.named_parameters( ("RaggedTensor_RaggedRank", structure.RaggedTensorStructure(dtypes.int32, None, 1), structure.RaggedTensorStructure(dtypes.int32, None, 2)), ("RaggedTensor_Shape", structure.RaggedTensorStructure(dtypes.int32, [3, None], 1), structure.RaggedTensorStructure(dtypes.int32, [5, None], 1)), ("RaggedTensor_DType", structure.RaggedTensorStructure(dtypes.int32, None, 1), structure.RaggedTensorStructure(dtypes.float32, None, 1)), ) def testRaggedStructureInequality(self, s1, s2): # pylint: disable=g-generic-assert self.assertNotEqual(s1, s2) # check __ne__ operator. self.assertFalse(s1 == s2) # check __eq__ operator. @parameterized.named_parameters( ("Tensor", lambda: constant_op.constant(37.0), lambda: constant_op.constant(42.0), lambda: constant_op.constant([5])), ("TensorArray", lambda: tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3, ), size=0), lambda: tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(3, ), size=0), lambda: tensor_array_ops.TensorArray( dtype=dtypes.int32, element_shape=(), size=0)), ("SparseTensor", lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: sparse_tensor.SparseTensor( indices=[[1, 2]], values=[42], dense_shape=[4, 5]), lambda: sparse_tensor.SparseTensor( indices=[[3]], values=[-1], dense_shape=[5])), ("Nested", lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) }, lambda: { "a": constant_op.constant(42.0), "b": constant_op.constant([4, 5, 6]) }, lambda: { "a": constant_op.constant([1, 2, 3]), "b": constant_op.constant(37.0) }), ) def testHash(self, value1_fn, value2_fn, value3_fn): s1 = structure.type_spec_from_value(value1_fn()) s2 = structure.type_spec_from_value(value2_fn()) s3 = structure.type_spec_from_value(value3_fn()) for c1, c2, c3 in zip(nest.flatten(s1), nest.flatten(s2), nest.flatten(s3)): self.assertEqual(hash(c1), hash(c1)) self.assertEqual(hash(c1), hash(c2)) self.assertNotEqual(hash(c1), hash(c3)) self.assertNotEqual(hash(c2), hash(c3)) @parameterized.named_parameters( ( "Tensor", lambda: constant_op.constant(37.0), ), ( "SparseTensor", lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), ), ("TensorArray", lambda: tensor_array_ops.TensorArray( dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)), ( "RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [3]]), ), ( "Nested_0", lambda: { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) }, ), ( "Nested_1", lambda: { "a": constant_op.constant(37.0), "b": (sparse_tensor.SparseTensor( indices=[[0, 0]], values=[1], dense_shape=[1, 1]), sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) }, ), ) def testRoundTripConversion(self, value_fn): value = value_fn() s = structure.type_spec_from_value(value) def maybe_stack_ta(v): if isinstance(v, tensor_array_ops.TensorArray): return v.stack() else: return v before = self.evaluate(maybe_stack_ta(value)) after = self.evaluate( maybe_stack_ta( structure.from_tensor_list(s, structure.to_tensor_list(s, value)))) flat_before = nest.flatten(before) flat_after = nest.flatten(after) for b, a in zip(flat_before, flat_after): if isinstance(b, sparse_tensor.SparseTensorValue): self.assertAllEqual(b.indices, a.indices) self.assertAllEqual(b.values, a.values) self.assertAllEqual(b.dense_shape, a.dense_shape) elif isinstance(b, (ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)): self.assertAllEqual(b, a) else: 
self.assertAllEqual(b, a) # pylint: enable=g-long-lambda def testPreserveStaticShape(self): rt = ragged_factory_ops.constant([[1, 2], [], [3]]) rt_s = structure.type_spec_from_value(rt) rt_after = structure.from_tensor_list( rt_s, structure.to_tensor_list(rt_s, rt)) self.assertEqual(rt_after.row_splits.shape.as_list(), rt.row_splits.shape.as_list()) self.assertEqual(rt_after.values.shape.as_list(), [None]) st = sparse_tensor.SparseTensor(indices=[[3, 4]], values=[-1], dense_shape=[4, 5]) st_s = structure.type_spec_from_value(st) st_after = structure.from_tensor_list( st_s, structure.to_tensor_list(st_s, st)) self.assertEqual(st_after.indices.shape.as_list(), [None, 2]) self.assertEqual(st_after.values.shape.as_list(), [None]) self.assertEqual(st_after.dense_shape.shape.as_list(), st.dense_shape.shape.as_list()) def testIncompatibleStructure(self): # Define three mutually incompatible values/structures, and assert that: # 1. Using one structure to flatten a value with an incompatible structure # fails. # 2. Using one structure to restructure a flattened value with an # incompatible structure fails. value_tensor = constant_op.constant(42.0) s_tensor = structure.type_spec_from_value(value_tensor) flat_tensor = structure.to_tensor_list(s_tensor, value_tensor) value_sparse_tensor = sparse_tensor.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]) s_sparse_tensor = structure.type_spec_from_value(value_sparse_tensor) flat_sparse_tensor = structure.to_tensor_list(s_sparse_tensor, value_sparse_tensor) value_nest = { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) } s_nest = structure.type_spec_from_value(value_nest) flat_nest = structure.to_tensor_list(s_nest, value_nest) with self.assertRaisesRegexp( ValueError, r"SparseTensor.* is not convertible to a tensor with " r"dtype.*float32.* and shape \(\)"): structure.to_tensor_list(s_tensor, value_sparse_tensor) with self.assertRaisesRegexp( ValueError, "The two structures don't have the same nested structure."): structure.to_tensor_list(s_tensor, value_nest) with self.assertRaisesRegexp( TypeError, "Neither a SparseTensor nor SparseTensorValue"): structure.to_tensor_list(s_sparse_tensor, value_tensor) with self.assertRaisesRegexp( ValueError, "The two structures don't have the same nested structure."): structure.to_tensor_list(s_sparse_tensor, value_nest) with self.assertRaisesRegexp( ValueError, "The two structures don't have the same nested structure."): structure.to_tensor_list(s_nest, value_tensor) with self.assertRaisesRegexp( ValueError, "The two structures don't have the same nested structure."): structure.to_tensor_list(s_nest, value_sparse_tensor) with self.assertRaisesRegexp(ValueError, r"Incompatible input:"): structure.from_tensor_list(s_tensor, flat_sparse_tensor) with self.assertRaisesRegexp(ValueError, "Expected 1 tensors but got 2."): structure.from_tensor_list(s_tensor, flat_nest) with self.assertRaisesRegexp(ValueError, "Incompatible input: "): structure.from_tensor_list(s_sparse_tensor, flat_tensor) with self.assertRaisesRegexp(ValueError, "Expected 1 tensors but got 2."): structure.from_tensor_list(s_sparse_tensor, flat_nest) with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 1."): structure.from_tensor_list(s_nest, flat_tensor) with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 1."): structure.from_tensor_list(s_nest, flat_sparse_tensor) def testIncompatibleNestedStructure(self): # Define three mutually incompatible nested values/structures, and assert # that: # 1. 
Using one structure to flatten a value with an incompatible structure # fails. # 2. Using one structure to restructure a flattened value with an # incompatible structure fails. value_0 = { "a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3]) } s_0 = structure.type_spec_from_value(value_0) flat_s_0 = structure.to_tensor_list(s_0, value_0) # `value_1` has compatible nested structure with `value_0`, but different # classes. value_1 = { "a": constant_op.constant(37.0), "b": sparse_tensor.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]) } s_1 = structure.type_spec_from_value(value_1) flat_s_1 = structure.to_tensor_list(s_1, value_1) # `value_2` has incompatible nested structure with `value_0` and `value_1`. value_2 = { "a": constant_op.constant(37.0), "b": (sparse_tensor.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]), sparse_tensor.SparseTensor(indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) } s_2 = structure.type_spec_from_value(value_2) flat_s_2 = structure.to_tensor_list(s_2, value_2) with self.assertRaisesRegexp( ValueError, r"SparseTensor.* is not convertible to a tensor with " r"dtype.*int32.* and shape \(3,\)"): structure.to_tensor_list(s_0, value_1) with self.assertRaisesRegexp( ValueError, "The two structures don't have the same nested structure."): structure.to_tensor_list(s_0, value_2) with self.assertRaisesRegexp( TypeError, "Neither a SparseTensor nor SparseTensorValue"): structure.to_tensor_list(s_1, value_0) with self.assertRaisesRegexp( ValueError, "The two structures don't have the same nested structure."): structure.to_tensor_list(s_1, value_2) # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp # needs to account for "a" coming before or after "b". It might be worth # adding a deterministic repr for these error messages (among other # improvements). 
with self.assertRaisesRegexp( ValueError, "The two structures don't have the same nested structure."): structure.to_tensor_list(s_2, value_0) with self.assertRaisesRegexp( ValueError, "The two structures don't have the same nested structure."): structure.to_tensor_list(s_2, value_1) with self.assertRaisesRegexp(ValueError, r"Incompatible input:"): structure.from_tensor_list(s_0, flat_s_1) with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 3."): structure.from_tensor_list(s_0, flat_s_2) with self.assertRaisesRegexp(ValueError, "Incompatible input: "): structure.from_tensor_list(s_1, flat_s_0) with self.assertRaisesRegexp(ValueError, "Expected 2 tensors but got 3."): structure.from_tensor_list(s_1, flat_s_2) with self.assertRaisesRegexp(ValueError, "Expected 3 tensors but got 2."): structure.from_tensor_list(s_2, flat_s_0) with self.assertRaisesRegexp(ValueError, "Expected 3 tensors but got 2."): structure.from_tensor_list(s_2, flat_s_1) @parameterized.named_parameters( ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor, structure.TensorStructure(dtypes.float32, [])), ("SparseTensor", dtypes.int32, tensor_shape.matrix( 2, 2), sparse_tensor.SparseTensor, structure.SparseTensorStructure(dtypes.int32, [2, 2])), ("TensorArray_0", dtypes.int32, tensor_shape.as_shape([None, True, 2, 2 ]), tensor_array_ops.TensorArray, structure.TensorArrayStructure( dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)), ("TensorArray_1", dtypes.int32, tensor_shape.as_shape([True, None, 2, 2 ]), tensor_array_ops.TensorArray, structure.TensorArrayStructure( dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)), ("TensorArray_2", dtypes.int32, tensor_shape.as_shape([True, False, 2, 2 ]), tensor_array_ops.TensorArray, structure.TensorArrayStructure( dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)), ("RaggedTensor", dtypes.int32, tensor_shape.matrix(2, None), structure.RaggedTensorStructure(dtypes.int32, [2, None], 1), structure.RaggedTensorStructure(dtypes.int32, [2, None], 1)), ("Nested", { "a": dtypes.float32, "b": (dtypes.int32, dtypes.string) }, { "a": tensor_shape.scalar(), "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar()) }, { "a": ops.Tensor, "b": (sparse_tensor.SparseTensor, ops.Tensor) }, { "a": structure.TensorStructure(dtypes.float32, []), "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]), structure.TensorStructure(dtypes.string, [])) }), ) def testConvertLegacyStructure(self, output_types, output_shapes, output_classes, expected_structure): actual_structure = structure.convert_legacy_structure( output_types, output_shapes, output_classes) self.assertEqual(actual_structure, expected_structure) def testNestedNestedStructure(self): s = (structure.TensorStructure(dtypes.int64, []), (structure.TensorStructure(dtypes.float32, []), structure.TensorStructure(dtypes.string, []))) int64_t = constant_op.constant(37, dtype=dtypes.int64) float32_t = constant_op.constant(42.0) string_t = constant_op.constant("Foo") nested_tensors = (int64_t, (float32_t, string_t)) tensor_list = structure.to_tensor_list(s, nested_tensors) for expected, actual in zip([int64_t, float32_t, string_t], tensor_list): self.assertIs(expected, actual) (actual_int64_t, (actual_float32_t, actual_string_t)) = structure.from_tensor_list(s, tensor_list) self.assertIs(int64_t, actual_int64_t) self.assertIs(float32_t, actual_float32_t) self.assertIs(string_t, actual_string_t) (actual_int64_t, (actual_float32_t, actual_string_t)) = (structure.from_compatible_tensor_list( s, tensor_list)) 
self.assertIs(int64_t, actual_int64_t) self.assertIs(float32_t, actual_float32_t) self.assertIs(string_t, actual_string_t) @parameterized.named_parameters( ("Tensor", structure.TensorStructure(dtypes.float32, []), 32, structure.TensorStructure(dtypes.float32, [32])), ("TensorUnknown", structure.TensorStructure(dtypes.float32, []), None, structure.TensorStructure(dtypes.float32, [None])), ("SparseTensor", structure.SparseTensorStructure( dtypes.float32, [None]), 32, structure.SparseTensorStructure(dtypes.float32, [32, None])), ("SparseTensorUnknown", structure.SparseTensorStructure(dtypes.float32, [4]), None, structure.SparseTensorStructure(dtypes.float32, [None, 4])), ("RaggedTensor", structure.RaggedTensorStructure(dtypes.float32, [2, None], 1), 32, structure.RaggedTensorStructure(dtypes.float32, [32, 2, None], 2)), ("RaggedTensorUnknown", structure.RaggedTensorStructure(dtypes.float32, [4, None], 1), None, structure.RaggedTensorStructure(dtypes.float32, [None, 4, None], 2)), ("Nested", { "a": structure.TensorStructure(dtypes.float32, []), "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]), structure.TensorStructure(dtypes.string, [])) }, 128, { "a": structure.TensorStructure(dtypes.float32, [128]), "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]), structure.TensorStructure(dtypes.string, [128])) }), ) def testBatch(self, element_structure, batch_size, expected_batched_structure): batched_structure = nest.map_structure( lambda component_spec: component_spec._batch(batch_size), element_structure) self.assertEqual(batched_structure, expected_batched_structure) @parameterized.named_parameters( ("Tensor", structure.TensorStructure(dtypes.float32, [32]), structure.TensorStructure(dtypes.float32, [])), ("TensorUnknown", structure.TensorStructure(dtypes.float32, [None]), structure.TensorStructure(dtypes.float32, [])), ("SparseTensor", structure.SparseTensorStructure(dtypes.float32, [32, None]), structure.SparseTensorStructure(dtypes.float32, [None])), ("SparseTensorUnknown", structure.SparseTensorStructure(dtypes.float32, [None, 4]), structure.SparseTensorStructure(dtypes.float32, [4])), ("RaggedTensor", structure.RaggedTensorStructure(dtypes.float32, [32, None, None], 2), structure.RaggedTensorStructure(dtypes.float32, [None, None], 1)), ("RaggedTensorUnknown", structure.RaggedTensorStructure(dtypes.float32, [None, None, None], 2), structure.RaggedTensorStructure(dtypes.float32, [None, None], 1)), ("Nested", { "a": structure.TensorStructure(dtypes.float32, [128]), "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]), structure.TensorStructure(dtypes.string, [None])) }, { "a": structure.TensorStructure(dtypes.float32, []), "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]), structure.TensorStructure(dtypes.string, [])) }), ) def testUnbatch(self, element_structure, expected_unbatched_structure): unbatched_structure = nest.map_structure( lambda component_spec: component_spec._unbatch(), element_structure) self.assertEqual(unbatched_structure, expected_unbatched_structure) # pylint: disable=g-long-lambda @parameterized.named_parameters( ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]), lambda: constant_op.constant([1.0, 2.0])), ("SparseTensor", lambda: sparse_tensor.SparseTensor( indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]), lambda: sparse_tensor.SparseTensor( indices=[[0]], values=[13], dense_shape=[2])), ("RaggedTensor", lambda: ragged_factory_ops.constant([[[1]], [[2]]]), lambda: 
ragged_factory_ops.constant([[1]])), ("Nest", lambda: (constant_op.constant([[1.0, 2.0], [3.0, 4.0]]), sparse_tensor.SparseTensor( indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])), lambda: (constant_op.constant([1.0, 2.0]), sparse_tensor.SparseTensor( indices=[[0]], values=[13], dense_shape=[2]))), ) def testToBatchedTensorList(self, value_fn, element_0_fn): batched_value = value_fn() s = structure.type_spec_from_value(batched_value) batched_tensor_list = structure.to_batched_tensor_list( s, batched_value) # The batch dimension is 2 for all of the test cases. # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT # tensors in which we store sparse tensors. for t in batched_tensor_list: if t.dtype != dtypes.variant: self.assertEqual(2, self.evaluate(array_ops.shape(t)[0])) # Test that the 0th element from the unbatched tensor is equal to the # expected value. expected_element_0 = self.evaluate(element_0_fn()) unbatched_s = nest.map_structure( lambda component_spec: component_spec._unbatch(), s) actual_element_0 = structure.from_tensor_list( unbatched_s, [t[0] for t in batched_tensor_list]) for expected, actual in zip(nest.flatten(expected_element_0), nest.flatten(actual_element_0)): self.assertValuesEqual(expected, actual)
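# Illustrative sketch, not part of the original file: the module-level
# functional API that this second StructureTest targets, in which the
# Structure methods of the earlier revision became free functions (assumes
# this file's imports; the helper name is ours).
def _example_type_spec_round_trip():
  """Round-trips a nested value through the functional structure API."""
  value = (constant_op.constant(37.0), constant_op.constant([1, 2, 3]))
  spec = structure.type_spec_from_value(value)    # nest of TensorSpecs
  flat = structure.to_tensor_list(spec, value)    # flatten to a tensor list
  types = structure.get_flat_tensor_types(spec)   # [tf.float32, tf.int32]
  restored = structure.from_tensor_list(spec, flat)
  return types, restored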
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def testAsSerializedGraph(self): dataset = dataset_ops.Dataset.range(10) graph = graph_pb2.GraphDef().FromString( self.evaluate(dataset._as_serialized_graph())) self.assertTrue(any([node.op == "RangeDataset" for node in graph.node])) def testAsFunctionWithMap(self): if not context.executing_eagerly(): self.skipTest("Only works executing eagerly") with ops.device("CPU"): original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2) fn = original_dataset._trace_variant_creation() variant = fn() revived_dataset = _RevivedDataset( variant, original_dataset._element_structure) self.assertDatasetProduces(revived_dataset, range(0, 10, 2)) def testAsFunctionWithMapInFlatMap(self): if not context.executing_eagerly(): self.skipTest("Only works executing eagerly") with ops.device("CPU"): original_dataset = dataset_ops.Dataset.range(5).flat_map( lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2)) fn = original_dataset._trace_variant_creation() variant = fn() revived_dataset = _RevivedDataset( variant, original_dataset._element_structure) self.assertDatasetProduces(revived_dataset, list(original_dataset)) @staticmethod def make_apply_fn(dataset): def apply_fn(dataset): def _apply_fn(dataset): return dataset.cache() return dataset.apply(_apply_fn) return apply_fn @staticmethod def make_gen(): def gen(): yield 42 return gen @staticmethod def make_interleave_fn(dataset, num_parallel_calls=None): def interleave_fn(dataset): return dataset.interleave( lambda x: dataset_ops.Dataset.range(0), cycle_length=2, num_parallel_calls=num_parallel_calls) return interleave_fn @parameterized.named_parameters( ("FixedLengthRecord", lambda: readers.FixedLengthRecordDataset("", 42)), ("FromGenerator", lambda: dataset_ops.Dataset.from_generator( DatasetTest.make_gen(), dtypes.int32), 1), ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])), ("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensor_slices([42])), ("Range", lambda: dataset_ops.Dataset.range(10)), ("TextLine", lambda: readers.TextLineDataset("")), ("TFRecord", lambda: readers.TFRecordDataset(""), 1), ) def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0): self.assertLen(dataset_fn()._inputs(), num_inputs) @test_util.run_v1_only("deprecated API, no eager or V2 test coverage") def testDatasetComplexSourceInputs(self): dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices( sparse_tensor.SparseTensor( indices=np.array([[0, 0], [1, 0], [2, 0]]), values=np.array([0, 0, 0]), dense_shape=np.array([3, 1]))) self.assertEmpty(dataset_fn._inputs()) @parameterized.named_parameters( ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)), ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)), ("Filter", lambda x: x.filter(lambda x: True), lambda: dataset_ops.Dataset.range(0)), ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)), lambda: dataset_ops.Dataset.range(0)), ("Map", lambda x: x.map(lambda x: x), lambda: dataset_ops.Dataset.range(0)), ("PaddedBatch", lambda x: x.padded_batch(10, []), lambda: dataset_ops.Dataset.range(0)), ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2), lambda: dataset_ops.Dataset.range(0)), ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)), ("Shuffle", lambda x: x.shuffle(10), lambda: dataset_ops.Dataset.range(0)), ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)), ("Take", lambda x: x.take(1), lambda: 
dataset_ops.Dataset.range(0)), ("Window", lambda x: x.window(10), lambda: dataset_ops.Dataset.range(0)), ) def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn): input_dataset = input_dataset_fn() self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs()) def testUnaryTransformationInputsApply(self): input_dataset = dataset_ops.Dataset.range(0) dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0)) self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs()) @parameterized.named_parameters( ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2], lambda: dataset_ops.Dataset.range(0)), ("Interleave", [lambda: dataset_ops.Dataset.range(0), None], lambda: dataset_ops.Dataset.range(0)), ) def testUnaryTransformationInputsWithInterleaveFn( self, interleave_fn_args, input_dataset_fn): input_dataset = input_dataset_fn() dataset_fn = self.make_interleave_fn(*interleave_fn_args) self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs()) def testNoWarnings(self): with test.mock.patch.object(warnings, "warn") as mock_log: dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10)) dataset_fn(dataset_ops.Dataset.range(10)) self.assertEmpty(mock_log.call_args_list) @parameterized.named_parameters( ("Concatenate", lambda x, y: x.concatenate(y), lambda: dataset_ops.Dataset.range(0), lambda: dataset_ops.Dataset.range(1))) def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn): input1 = input1_fn() input2 = input2_fn() self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs()) @parameterized.named_parameters( ("ZipOne", dataset_ops.Dataset.zip, lambda: (dataset_ops.Dataset.range(0))), ("ZipNest", dataset_ops.Dataset.zip, lambda: (dataset_ops.Dataset.range(0), (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))), ("ZipTuple", dataset_ops.Dataset.zip, lambda: (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))), ) def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn): input_datasets = input_datasets_fn() self.assertEqual( nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs()) def testFunctions(self): dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2) self.assertLen(dataset._functions(), 1) def testCollectInputs(self): ds1 = dataset_ops.Dataset.range(0) ds2 = ds1.concatenate(ds1) ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2)) inputs = [] queue = [ds3] while queue: ds = queue[0] queue = queue[1:] queue.extend(ds._inputs()) inputs.append(ds) self.assertEqual(5, inputs.count(ds1)) self.assertEqual(2, inputs.count(ds2)) self.assertEqual(1, inputs.count(ds3)) # pylint: disable=g-long-lambda @parameterized.named_parameters( ("Tensor", lambda: constant_op.constant(37.0), structure.TensorStructure(dtypes.float32, [])), ("SparseTensor", lambda: sparse_tensor.SparseTensor( indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32), dense_shape=[1]), structure.SparseTensorStructure(dtypes.int32, [1])), ("Nest", lambda: { "a": constant_op.constant(37.0), "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))}, structure.NestedStructure({ "a": structure.TensorStructure(dtypes.float32, []), "b": (structure.TensorStructure(dtypes.string, [1]), structure.TensorStructure(dtypes.string, []))})), ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices( constant_op.constant([1, 2, 3])), dataset_ops.DatasetStructure( structure.TensorStructure(dtypes.int32, []))), ("Optional", lambda: optional_ops.Optional.from_value(37.0), 
optional_ops.OptionalStructure( structure.TensorStructure(dtypes.float32, []))), ) def testDatasetStructure(self, tf_value_fn, expected_element_structure): dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn()) dataset_structure = structure.Structure.from_value(dataset) self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure) # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing # the element structure. self.assertTrue(expected_element_structure.is_compatible_with( dataset_structure._element_structure)) self.assertTrue(dataset_structure._element_structure.is_compatible_with( expected_element_structure)) self.assertEqual([dtypes.variant], dataset_structure._flat_types) self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes) # Assert that the `Dataset` survives a round-trip via _from_tensor_list() # and _to_tensor_list(). round_trip_dataset = dataset_structure._from_tensor_list( dataset_structure._to_tensor_list(dataset)) value = tf_value_fn() if isinstance(value, dataset_ops.Dataset): self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x)) elif isinstance(value, optional_ops.Optional): self.assertDatasetProduces( round_trip_dataset.map(lambda opt: opt.get_value()), [self.evaluate(value.get_value())], requires_initialization=True) else: self.assertDatasetProduces( round_trip_dataset, [self.evaluate(tf_value_fn())], requires_initialization=True) @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage") def testSkipEagerSameGraphErrorOneShot(self): dataset = dataset_ops.Dataset.range(10) with ops.Graph().as_default(): with self.assertRaisesRegexp(ValueError, "must be from the same graph"): dataset = dataset.batch(2) @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage") def testSkipEagerSameGraphErrorOneShotSimple(self): dataset = dataset_ops.Dataset.range(10) with ops.Graph().as_default(): with test.mock.patch.object(logging, "warning") as mock_log: _ = dataset_ops.make_one_shot_iterator(dataset) self.assertRegexpMatches( str(mock_log.call_args), "Please ensure that all datasets in the " "pipeline are created in the same graph as the iterator.") @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage") def testSkipEagerSameGraphErrorInitializable(self): dataset = dataset_ops.Dataset.range(10) with ops.Graph().as_default(): with self.assertRaisesRegexp(ValueError, "must be from the same graph"): dataset = dataset.batch(2) @parameterized.named_parameters( ("Async", context.ASYNC), ("Sync", context.SYNC), ) def testDatasetEagerIteration(self, execution_mode): with context.eager_mode(), context.execution_mode(execution_mode): val = 0 dataset = dataset_ops.Dataset.range(10) for foo in dataset: self.assertEqual(val, foo.numpy()) val += 1 def testDatasetAsFunctionArgument(self): @def_function.function def _uses_dataset(d): accumulator = array_ops.zeros([], dtype=dtypes.int64) for value in d: accumulator += value return accumulator with ops.device("CPU"): first_dataset = dataset_ops.Dataset.range(10) self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset))) second_dataset = dataset_ops.Dataset.range(11) self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset))) first_concrete = _uses_dataset.get_concrete_function(first_dataset) # The dataset should not be a captured input self.assertEmpty(first_concrete.graph.captures) # The two datasets have the same structure and so should re-use a trace. 
self.assertIs(first_concrete, _uses_dataset.get_concrete_function(second_dataset)) # With a different structure we should use a different trace. self.assertIsNot( first_concrete, _uses_dataset.get_concrete_function( dataset_ops.Dataset.zip((first_dataset, second_dataset)))) def testLimitedRetracing(self): trace_count = [0] @def_function.function def f(ds): trace_count[0] += 1 counter = np.int64(0) for elem in ds: counter += elem return counter dataset = dataset_ops.Dataset.range(5) dataset2 = dataset_ops.Dataset.range(10) for _ in range(10): self.assertEqual(self.evaluate(f(dataset)), 10) self.assertEqual(self.evaluate(f(dataset2)), 45) self.assertEqual(trace_count[0], 1)
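

# Illustrative sketch (not part of the test suite above): the variant round
# trip that testAsFunctionWithMap performs through the private
# `_trace_variant_creation()` can also be expressed with the
# `dataset_ops.to_variant()` / `dataset_ops.from_variant()` helpers used
# elsewhere in this file. The helper name `_variant_round_trip` is made up
# for illustration; `_element_structure` is private API, so treat this as a
# sketch rather than a supported pattern.
def _variant_round_trip(dataset):
  """Serializes `dataset` to a scalar variant tensor and revives it."""
  variant = dataset_ops.to_variant(dataset)
  return dataset_ops.from_variant(
      variant, structure=dataset._element_structure)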
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    # The serialized graph of a range dataset must contain a RangeDataset op.
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord", lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      ("FromTensorSlices",
       lambda: dataset_ops.Dataset.from_tensor_slices([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

  def testDatasetComplexSourceInputs(self):
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEqual(0, len(dataset_fn._inputs()))

  @parameterized.named_parameters(
      ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
      ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
      ("Filter", lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map", lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch", lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle", lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
      ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
      ("Window", lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave", [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple", dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets), dataset_fn(input_datasets)._inputs())

  def testCollectInputs(self):
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue[0]
      queue = queue[1:]
      queue.extend(ds._inputs())
      inputs.append(ds)

    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # TODO(b/119882922): use-after-free bug in eager mode.
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
      },
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testSkipEagerDatasetStructure(self, tf_value_fn,
                                    expected_element_structure):
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(expected_element_structure.is_compatible_with(
        dataset_structure._element_structure))
    self.assertTrue(dataset_structure._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value,
                               round_trip_dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset,
          [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShot(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      dataset = dataset.batch(2)
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset.make_one_shot_iterator()
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorInitializable(self):
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      dataset = dataset.batch(2)
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        _ = dataset.make_initializable_iterator()
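

# A small illustrative helper (not used by the tests above): the breadth-first
# walk that testCollectInputs performs inline, factored out. `_inputs()` is
# private API and the helper name `_collect_inputs` is made up for
# illustration, so treat this as a sketch rather than a supported pattern.
def _collect_inputs(dataset):
  """Returns every dataset reachable from `dataset` via `_inputs()`.

  Datasets that appear more than once in the input graph (e.g. a dataset
  concatenated with itself) are returned once per occurrence, which is what
  lets testCollectInputs count `ds1` five times.
  """
  visited = []
  queue = [dataset]
  while queue:
    ds = queue.pop(0)
    visited.append(ds)
    queue.extend(ds._inputs())
  return visited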
class StructureTest(test.TestCase, parameterized.TestCase):
  # NOTE(mrry): The arguments must be lifted into lambdas because otherwise
  # they will be executed before the (eager- or graph-mode) test environment
  # has been set up.
  # pylint: disable=g-long-lambda,protected-access

  @parameterized.parameters(
      (lambda: constant_op.constant(37.0), structure.TensorStructure,
       [dtypes.float32], [[]]),
      (lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       structure.SparseTensorStructure, [dtypes.variant], [None]),
      (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
       structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
      (lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, structure.NestedStructure, [dtypes.float32, dtypes.int32],
       [[], [3]]),
      (lambda: {
          "a": constant_op.constant(37.0),
          "b": (sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
      }, structure.NestedStructure,
       [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None]))
  def testFlatStructure(self, value_fn, expected_structure, expected_types,
                        expected_shapes):
    value = value_fn()
    s = structure.Structure.from_value(value)
    self.assertIsInstance(s, expected_structure)
    self.assertEqual(expected_types, s._flat_types)
    for expected, actual in zip(expected_shapes, s._flat_shapes):
      self.assertTrue(actual.is_compatible_with(expected))
      self.assertTrue(
          tensor_shape.as_shape(expected).is_compatible_with(actual))

  @parameterized.parameters(
      (lambda: constant_op.constant(37.0), lambda: [
          constant_op.constant(38.0),
          array_ops.placeholder(dtypes.float32),
          variables.Variable(100.0), 42.0,
          np.array(42.0, dtype=np.float32)
      ], lambda: [constant_op.constant([1.0, 2.0]),
                  constant_op.constant(37)]),
      (lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: [
           sparse_tensor.SparseTensor(
               indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
           sparse_tensor.SparseTensorValue(
               indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
           array_ops.sparse_placeholder(dtype=dtypes.int32),
           array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None])
       ], lambda: [
           constant_op.constant(37, shape=[4, 5]),
           sparse_tensor.SparseTensor(
               indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
           array_ops.sparse_placeholder(
               dtype=dtypes.int32, shape=[None, None, None]),
           sparse_tensor.SparseTensor(
               indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
       ]),
      (lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, lambda: [{
          "a": constant_op.constant(15.0),
          "b": constant_op.constant([4, 5, 6])
      }], lambda: [{
          "a": constant_op.constant(15.0),
          "b": constant_op.constant([4, 5, 6, 7])
      }, {
          "a": constant_op.constant(15),
          "b": constant_op.constant([4, 5, 6])
      }, {
          "a": constant_op.constant(15),
          "b": sparse_tensor.SparseTensor(
              indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
      }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
  )
  def testIsCompatibleWithStructure(self, original_value_fn,
                                    compatible_values_fn,
                                    incompatible_values_fn):
    original_value = original_value_fn()
    compatible_values = compatible_values_fn()
    incompatible_values = incompatible_values_fn()

    s = structure.Structure.from_value(original_value)
    for compatible_value in compatible_values:
      self.assertTrue(
          s.is_compatible_with(
              structure.Structure.from_value(compatible_value)))
    for incompatible_value in incompatible_values:
      self.assertFalse(
          s.is_compatible_with(
              structure.Structure.from_value(incompatible_value)))

  @parameterized.parameters(
      (lambda: constant_op.constant(37.0),),
      (lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),),
      (lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      },),
      (lambda: {
          "a": constant_op.constant(37.0),
          "b": (sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
      },),
  )
  def testRoundTripConversion(self, value_fn):
    value = value_fn()
    s = structure.Structure.from_value(value)
    before = self.evaluate(value)
    after = self.evaluate(s._from_tensor_list(s._to_tensor_list(value)))

    flat_before = nest.flatten(before)
    flat_after = nest.flatten(after)
    for b, a in zip(flat_before, flat_after):
      if isinstance(b, sparse_tensor.SparseTensorValue):
        self.assertAllEqual(b.indices, a.indices)
        self.assertAllEqual(b.values, a.values)
        self.assertAllEqual(b.dense_shape, a.dense_shape)
      else:
        self.assertAllEqual(b, a)

  # pylint: enable=g-long-lambda

  def testIncompatibleStructure(self):
    # Define three mutually incompatible values/structures, and assert that:
    # 1. Using one structure to flatten a value with an incompatible structure
    #    fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.
    value_tensor = constant_op.constant(42.0)
    s_tensor = structure.Structure.from_value(value_tensor)
    flat_tensor = s_tensor._to_tensor_list(value_tensor)

    value_sparse_tensor = sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor)
    flat_sparse_tensor = s_sparse_tensor._to_tensor_list(value_sparse_tensor)

    value_nest = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_nest = structure.Structure.from_value(value_nest)
    flat_nest = s_nest._to_tensor_list(value_nest)

    with self.assertRaisesRegexp(
        ValueError, r"SparseTensor.* is not convertible to a tensor with "
        r"dtype.*float32.* and shape \(\)"):
      s_tensor._to_tensor_list(value_sparse_tensor)
    with self.assertRaisesRegexp(
        ValueError, r"Value \{.*\} is not convertible to a tensor with "
        r"dtype.*float32.* and shape \(\)"):
      s_tensor._to_tensor_list(value_nest)

    with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
      s_sparse_tensor._to_tensor_list(value_tensor)
    with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
      s_sparse_tensor._to_tensor_list(value_nest)

    with self.assertRaisesRegexp(
        ValueError, "Tensor.* not compatible with the nested structure "
        ".*TensorStructure.*TensorStructure"):
      s_nest._to_tensor_list(value_tensor)
    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.* not compatible with the nested structure "
        ".*TensorStructure.*TensorStructure"):
      s_nest._to_tensor_list(value_sparse_tensor)

    with self.assertRaisesRegexp(
        ValueError, r"Cannot convert.*with dtype.*float32.* and shape \(\)"):
      s_tensor._from_tensor_list(flat_sparse_tensor)
    with self.assertRaisesRegexp(
        ValueError, "TensorStructure corresponds to a single tf.Tensor."):
      s_tensor._from_tensor_list(flat_nest)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3."):
      s_sparse_tensor._from_tensor_list(flat_tensor)
    with self.assertRaisesRegexp(
        ValueError, "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3."):
      s_sparse_tensor._from_tensor_list(flat_nest)

    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 1."):
      s_nest._from_tensor_list(flat_tensor)
    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 1."):
      s_nest._from_tensor_list(flat_sparse_tensor)

  def testIncompatibleNestedStructure(self):
    # Define three mutually incompatible nested values/structures, and assert
    # that:
    # 1. Using one structure to flatten a value with an incompatible structure
    #    fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.

    value_0 = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_0 = structure.Structure.from_value(value_0)
    flat_s_0 = s_0._to_tensor_list(value_0)

    # `value_1` has compatible nested structure with `value_0`, but different
    # classes.
    value_1 = {
        "a": constant_op.constant(37.0),
        "b": sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    }
    s_1 = structure.Structure.from_value(value_1)
    flat_s_1 = s_1._to_tensor_list(value_1)

    # `value_2` has incompatible nested structure with `value_0` and
    # `value_1`.
    value_2 = {
        "a": constant_op.constant(37.0),
        "b": (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
    }
    s_2 = structure.Structure.from_value(value_2)
    flat_s_2 = s_2._to_tensor_list(value_2)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.* not compatible with the nested structure "
        ".*TensorStructure"):
      s_0._to_tensor_list(value_1)
    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
        "nested structure .*TensorStructure"):
      s_0._to_tensor_list(value_2)

    with self.assertRaisesRegexp(
        ValueError, "Tensor.* not compatible with the nested structure "
        ".*SparseTensorStructure"):
      s_1._to_tensor_list(value_0)
    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
        "nested structure .*TensorStructure"):
      s_1._to_tensor_list(value_2)

    # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
    # needs to account for "a" coming before or after "b". It might be worth
    # adding a deterministic repr for these error messages (among other
    # improvements).
    with self.assertRaisesRegexp(
        ValueError, "Tensor.*Tensor.* not compatible with the nested "
        "structure .*"
        "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
        "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
      s_2._to_tensor_list(value_0)
    with self.assertRaisesRegexp(
        ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
        "not compatible with the nested structure .*"
        "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
        "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
      s_2._to_tensor_list(value_1)

    with self.assertRaisesRegexp(
        ValueError, r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"):
      s_0._from_tensor_list(flat_s_1)
    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 3."):
      s_0._from_tensor_list(flat_s_2)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3."):
      s_1._from_tensor_list(flat_s_0)
    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 3."):
      s_1._from_tensor_list(flat_s_2)

    with self.assertRaisesRegexp(
        ValueError, "Expected 3 flat values in NestedStructure but got 2."):
      s_2._from_tensor_list(flat_s_0)
    with self.assertRaisesRegexp(
        ValueError, "Expected 3 flat values in NestedStructure but got 2."):
      s_2._from_tensor_list(flat_s_1)

  @parameterized.named_parameters(
      ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", dtypes.int32, tensor_shape.matrix(2, 2),
       sparse_tensor.SparseTensor,
       structure.SparseTensorStructure(dtypes.int32, [2, 2])),
      ("Nest", {
          "a": dtypes.float32,
          "b": (dtypes.int32, dtypes.string)
      }, {
          "a": tensor_shape.scalar(),
          "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())
      }, {
          "a": ops.Tensor,
          "b": (sparse_tensor.SparseTensor, ops.Tensor)
      },
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                 structure.TensorStructure(dtypes.string, []))
       })),
  )
  def testFromLegacyStructure(self, output_types, output_shapes,
                              output_classes, expected_structure):
    actual_structure = structure.Structure._from_legacy_structure(
        output_types, output_shapes, output_classes)
    self.assertTrue(expected_structure.is_compatible_with(actual_structure))
    self.assertTrue(actual_structure.is_compatible_with(expected_structure))
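

# Illustrative usage (a sketch, not a test) of the legacy-structure conversion
# exercised by testFromLegacyStructure above: the (output_types, output_shapes,
# output_classes) triple that pre-`Structure` tf.data APIs carried around maps
# onto a single `Structure` object. `_from_legacy_structure` is private API,
# and the helper name `_legacy_scalar_float_structure` is made up for
# illustration.
def _legacy_scalar_float_structure():
  """Builds the Structure for a scalar float32 tf.Tensor from legacy metadata."""
  return structure.Structure._from_legacy_structure(
      dtypes.float32, tensor_shape.scalar(), ops.Tensor)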