def testIteratorGetNextAsOptional(self, np_value, tf_value_fn, gpu_compatible):
  if not gpu_compatible and test.is_gpu_available():
    self.skipTest("Test case not yet supported on GPU.")
  ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)

  if context.executing_eagerly():
    iterator = dataset_ops.make_one_shot_iterator(ds)
    # For each element of the dataset, assert that the optional evaluates to
    # the expected value.
    for _ in range(3):
      next_elem = iterator_ops.get_next_as_optional(iterator)
      self.assertIsInstance(next_elem, optional_ops.Optional)
      self.assertTrue(
          structure.are_compatible(
              next_elem.element_spec,
              structure.type_spec_from_value(tf_value_fn())))
      self.assertTrue(next_elem.has_value())
      self.assertValuesEqual(np_value, next_elem.get_value())
    # After exhausting the iterator, `next_elem.has_value()` will evaluate to
    # false, and attempting to get the value will fail.
    for _ in range(2):
      next_elem = iterator_ops.get_next_as_optional(iterator)
      self.assertFalse(self.evaluate(next_elem.has_value()))
      with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(next_elem.get_value())
  else:
    iterator = dataset_ops.make_initializable_iterator(ds)
    next_elem = iterator_ops.get_next_as_optional(iterator)
    self.assertIsInstance(next_elem, optional_ops.Optional)
    self.assertTrue(
        structure.are_compatible(
            next_elem.element_spec,
            structure.type_spec_from_value(tf_value_fn())))
    # Before initializing the iterator, evaluating the optional fails with
    # a FailedPreconditionError. This is only relevant in graph mode.
    elem_has_value_t = next_elem.has_value()
    elem_value_t = next_elem.get_value()
    with self.assertRaises(errors.FailedPreconditionError):
      self.evaluate(elem_has_value_t)
    with self.assertRaises(errors.FailedPreconditionError):
      self.evaluate(elem_value_t)
    # Now we initialize the iterator.
    self.evaluate(iterator.initializer)
    # For each element of the dataset, assert that the optional evaluates to
    # the expected value.
    for _ in range(3):
      elem_has_value, elem_value = self.evaluate(
          [elem_has_value_t, elem_value_t])
      self.assertTrue(elem_has_value)
      self.assertValuesEqual(np_value, elem_value)
    # After exhausting the iterator, `next_elem.has_value()` will evaluate to
    # false, and attempting to get the value will fail.
    for _ in range(2):
      self.assertFalse(self.evaluate(elem_has_value_t))
      with self.assertRaises(errors.InvalidArgumentError):
        self.evaluate(elem_value_t)
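
# Illustrative sketch (not from the original test): the same Optional contract
# via the public TF 2.x API, assuming a release where `tf.data.Iterator`
# exposes `get_next_as_optional()`.
import tensorflow as tf

ds = tf.data.Dataset.from_tensors(37.0).repeat(2)
it = iter(ds)
for _ in range(2):
  opt = it.get_next_as_optional()
  assert bool(opt.has_value())            # an element is still available
  assert opt.get_value().numpy() == 37.0
opt = it.get_next_as_optional()           # the iterator is now exhausted
assert not bool(opt.has_value())
# Calling opt.get_value() here would raise tf.errors.InvalidArgumentError.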

def testIsCompatibleWithStructure(self, original_value_fn,
                                  compatible_values_fn,
                                  incompatible_values_fn):
  original_value = original_value_fn()
  compatible_values = compatible_values_fn()
  incompatible_values = incompatible_values_fn()

  s = structure.type_spec_from_value(original_value)
  for compatible_value in compatible_values:
    self.assertTrue(
        structure.are_compatible(
            s, structure.type_spec_from_value(compatible_value)))
  for incompatible_value in incompatible_values:
    self.assertFalse(
        structure.are_compatible(
            s, structure.type_spec_from_value(incompatible_value)))
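
# Illustrative sketch (not from the original test): what "compatible" means in
# terms of the public TypeSpec API, assuming a TF 2.x release that exports
# `tf.type_spec_from_value`.
import tensorflow as tf

s1 = tf.type_spec_from_value(tf.constant([1.0, 2.0]))  # TensorSpec([2], float32)
s2 = tf.type_spec_from_value(tf.constant([3.0, 4.0]))  # same shape and dtype
s3 = tf.type_spec_from_value(tf.constant([1, 2]))      # int32: dtype differs
assert s1.is_compatible_with(s2)
assert not s1.is_compatible_with(s3)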

def testCopySparseTensorsToDeviceWithPrefetch(self):

  def make_tensor(i):
    # `i` is a scalar tf.Tensor produced by `Dataset.range`, so `i * [1]`
    # broadcasts to the one-element tensor `[i]`.
    return sparse_tensor.SparseTensorValue(
        indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])

  host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
  device_dataset = host_dataset.apply(
      prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_one_shot_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(host_dataset),
          dataset_ops.get_structure(device_dataset)))

  self.assertEqual(dtypes.int64, next_element.dtype)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    for i in range(10):
      actual = self.evaluate(next_element)
      self.assertAllEqual([i], actual.values)
      self.assertAllEqual([[0, 0]], actual.indices)
      self.assertAllEqual([2, 2], actual.dense_shape)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)

def testCopyToDeviceWithReInitAndPrefetch(self):
  host_dataset = dataset_ops.Dataset.range(10)
  device_dataset = host_dataset.apply(
      prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_initializable_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(host_dataset),
          dataset_ops.get_structure(device_dataset)))

  self.assertEqual(dtypes.int64, next_element.dtype)
  self.assertEqual([], next_element.shape)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    self.evaluate(iterator.initializer)
    for i in range(5):
      self.assertEqual(i, self.evaluate(next_element))
    self.evaluate(iterator.initializer)
    for i in range(10):
      self.assertEqual(i, self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
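
# Illustrative sketch (not from the original test): the public entry point the
# copy_to_device tests exercise. Assumes a machine with a GPU; substitute a
# second CPU device otherwise.
import tensorflow as tf

ds = tf.data.Dataset.range(4)
ds = ds.apply(tf.data.experimental.copy_to_device("/gpu:0")).prefetch(1)
with tf.device("/gpu:0"):
  for x in ds:
    print(x.numpy())  # 0, 1, 2, 3, with elements materialized on the GPU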

def checkDatasetSpec(self, tf_value, expected_element_structure):
  dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value)
  dataset_structure = structure.type_spec_from_value(dataset)
  self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(dataset), expected_element_structure))
  self.assertEqual([dtypes.variant],
                   structure.get_flat_tensor_types(dataset_structure))
  self.assertEqual([tensor_shape.TensorShape([])],
                   structure.get_flat_tensor_shapes(dataset_structure))

  # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
  # and _to_tensor_list().
  round_trip_dataset = dataset_structure._from_tensor_list(
      dataset_structure._to_tensor_list(dataset))

  value = tf_value

  if isinstance(value, dataset_ops.Dataset):
    self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
  elif isinstance(value, optional_ops.Optional):
    self.assertDatasetProduces(
        round_trip_dataset.map(lambda opt: opt.get_value()),
        [self.evaluate(value.get_value())],
        requires_initialization=True)
  else:
    self.assertDatasetProduces(
        round_trip_dataset, [self.evaluate(tf_value)],
        requires_initialization=True)
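
# Illustrative sketch (not from the original test): datasets are first-class
# TF values with their own spec, reachable via the public `tf.data.DatasetSpec`.
import tensorflow as tf

ds = tf.data.Dataset.range(3)
spec = tf.data.DatasetSpec.from_value(ds)
print(spec.element_spec)  # TensorSpec(shape=(), dtype=tf.int64, name=None)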

def assertDatasetsEqual(self, dataset1, dataset2):
  """Checks that datasets are equal. Supports both graph and eager mode."""
  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(dataset1),
          dataset_ops.get_structure(dataset2)))

  flattened_types = nest.flatten(
      dataset_ops.get_legacy_output_types(dataset1))

  next1 = self.getNext(dataset1)
  next2 = self.getNext(dataset2)

  while True:
    try:
      op1 = self.evaluate(next1())
    except errors.OutOfRangeError:
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next2())
      break
    op2 = self.evaluate(next2())

    op1 = nest.flatten(op1)
    op2 = nest.flatten(op2)
    assert len(op1) == len(op2)
    for i in range(len(op1)):
      if sparse_tensor.is_sparse(op1[i]) or ragged_tensor.is_ragged(op1[i]):
        self.assertValuesEqual(op1[i], op2[i])
      elif flattened_types[i] == dtypes.string:
        self.assertAllEqual(op1[i], op2[i])
      else:
        self.assertAllClose(op1[i], op2[i])

def testIteratorStringHandle(self):
  dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
  dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

  iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
  iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

  handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
  feedable_iterator = iterator_ops.Iterator.from_string_handle(
      handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
      dataset_ops.get_legacy_output_shapes(dataset_3))
  next_element = feedable_iterator.get_next()

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(dataset_3),
          dataset_ops.get_structure(feedable_iterator)))

  with self.cached_session() as sess:
    iterator_3_handle = sess.run(iterator_3.string_handle())
    iterator_4_handle = sess.run(iterator_4.string_handle())

    self.assertEqual(
        10,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle}))
    self.assertEqual(
        1,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_3_handle}))
    self.assertEqual(
        20,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle}))
    self.assertEqual(
        2,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_3_handle}))
    self.assertEqual(
        30,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle}))
    self.assertEqual(
        3,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_3_handle}))
    self.assertEqual(
        40,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle}))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element,
               feed_dict={handle_placeholder: iterator_3_handle})
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element,
               feed_dict={handle_placeholder: iterator_4_handle})

def choose_from_datasets_v2(datasets,
                            choice_dataset,
                            stop_on_empty_dataset=False):
  """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A non-empty list of `tf.data.Dataset` objects with compatible
      structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
      `0` and `len(datasets) - 1`.
    stop_on_empty_dataset: If `True`, selection stops if it encounters an
      empty dataset. If `False`, it skips empty datasets. It is recommended to
      set it to `True`. Otherwise, the selected elements start off as the user
      intends, but may change as input datasets become empty. This can be
      difficult to detect since the dataset starts off looking correct.
      Defaults to `False` for backward compatibility.

  Returns:
    A dataset that interleaves elements from `datasets` according to the
    values of `choice_dataset`.

  Raises:
    TypeError: If `datasets` or `choice_dataset` has the wrong type.
    ValueError: If `datasets` is empty.
  """
  if not datasets:
    raise ValueError("`datasets` must be a non-empty list of datasets.")
  if choice_dataset is None or not structure.are_compatible(
      choice_dataset.element_spec, tensor_spec.TensorSpec([], dtypes.int64)):
    raise TypeError("`choice_dataset` must be a dataset of scalar "
                    "`tf.int64` tensors.")
  # pylint: disable=protected-access
  return dataset_ops._DirectedInterleaveDataset(choice_dataset, datasets,
                                                stop_on_empty_dataset)
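
# Illustrative sketch (not from the original source): the effect of
# `stop_on_empty_dataset=True`, assuming a TF release whose
# `tf.data.experimental.choose_from_datasets` accepts the flag. Selection
# halts the first time `choice_dataset` points at an exhausted input.
import tensorflow as tf

datasets = [tf.data.Dataset.from_tensors("foo"),           # a single element
            tf.data.Dataset.from_tensors("bar").repeat()]  # unbounded
choice = tf.data.Dataset.from_tensor_slices(
    tf.constant([0, 1, 0, 1], dtype=tf.int64))
result = tf.data.experimental.choose_from_datasets(
    datasets, choice, stop_on_empty_dataset=True)
print(list(result.as_numpy_iterator()))  # [b'foo', b'bar']: the second `0`
                                         # hits an empty dataset and stops.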

def testOptionalStructure(self, tf_value_fn, expected_value_structure):
  tf_value = tf_value_fn()
  opt = optional_ops.Optional.from_value(tf_value)

  self.assertTrue(
      structure.are_compatible(opt.value_structure,
                               expected_value_structure))

  opt_structure = structure.type_spec_from_value(opt)
  self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
  self.assertTrue(structure.are_compatible(opt_structure, opt_structure))
  self.assertTrue(
      structure.are_compatible(opt_structure._value_structure,
                               expected_value_structure))
  self.assertEqual([dtypes.variant],
                   structure.get_flat_tensor_types(opt_structure))
  self.assertEqual([tensor_shape.scalar()],
                   structure.get_flat_tensor_shapes(opt_structure))

  # An `OptionalStructure` is never compatible with a non-optional value.
  non_optional_structure = structure.type_spec_from_value(
      constant_op.constant(42.0))
  self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))

  # Assert that the optional survives a round-trip via _from_tensor_list()
  # and _to_tensor_list().
  round_trip_opt = opt_structure._from_tensor_list(
      opt_structure._to_tensor_list(opt))
  if isinstance(tf_value, optional_ops.Optional):
    self._assertElementValueEqual(
        self.evaluate(tf_value.get_value()),
        self.evaluate(round_trip_opt.get_value().get_value()))
  else:
    self._assertElementValueEqual(
        self.evaluate(tf_value),
        self.evaluate(round_trip_opt.get_value()))
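
# Illustrative sketch (not from the original test): the Optional contract
# through the public API, assuming a TF release that exports
# `tf.experimental.Optional`.
import tensorflow as tf

opt = tf.experimental.Optional.from_value(42.0)
assert bool(opt.has_value())
assert opt.get_value().numpy() == 42.0

empty = tf.experimental.Optional.empty(opt.element_spec)
assert not bool(empty.has_value())  # get_value() on `empty` would raise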

def testIteratorStructure(self, tf_value_fn, expected_element_structure,
                          expected_output_classes, expected_output_types,
                          expected_output_shapes):
  tf_value = tf_value_fn()
  iterator = dataset_ops.make_one_shot_iterator(
      dataset_ops.Dataset.from_tensors(tf_value))

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(iterator), expected_element_structure))
  self.assertEqual(expected_output_classes,
                   dataset_ops.get_legacy_output_classes(iterator))
  self.assertEqual(expected_output_types,
                   dataset_ops.get_legacy_output_types(iterator))
  self.assertEqual(expected_output_shapes,
                   dataset_ops.get_legacy_output_shapes(iterator))

def _assert_datasets_equal(self, ds1, ds2):
  # First, assert that the structures are the same.
  self.assertTrue(
      structure.are_compatible(ds1.element_spec, ds2.element_spec))

  # Now create iterators on both and assert they produce the same values.
  it1 = dataset_ops.make_initializable_iterator(ds1)
  it2 = dataset_ops.make_initializable_iterator(ds2)

  get_next1 = it1.get_next()
  get_next2 = it2.get_next()

  with self.cached_session():
    self.evaluate([it1.initializer, it2.initializer])
    val1, val2 = self.evaluate([get_next1, get_next2])
    self.assertEqual(val1, val2)

def choose_from_datasets_v2(datasets, choice_dataset):
  """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A list of `tf.data.Dataset` objects with compatible structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
      `0` and `len(datasets) - 1`.

  Returns:
    A dataset that interleaves elements from `datasets` according to the
    values of `choice_dataset`.

  Raises:
    TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
      type.
  """
  if not structure.are_compatible(
      choice_dataset.element_spec,
      structure.TensorStructure(dtypes.int64, [])):
    raise TypeError("`choice_dataset` must be a dataset of scalar "
                    "`tf.int64` tensors.")
  return _DirectedInterleaveDataset(choice_dataset, datasets)
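
# Illustrative sketch (not from the original source): the spec check above
# rejects a `choice_dataset` whose elements are not scalar int64. Hypothetical
# snippet against the public API; the exact error message may vary by version.
import tensorflow as tf

datasets = [tf.data.Dataset.from_tensors("foo").repeat()]
bad_choice = tf.data.Dataset.range(1).map(lambda x: tf.cast(x, tf.int32))
try:
  tf.data.experimental.choose_from_datasets(datasets, bad_choice)
except TypeError as e:
  print(e)  # complains that elements must be scalar `tf.int64` tensors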

def testCopyToDeviceInt32(self):
  host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
  device_dataset = host_dataset.apply(
      prefetching_ops.copy_to_device("/cpu:1"))

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_one_shot_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(host_dataset),
          dataset_ops.get_structure(device_dataset)))

  self.assertEqual(dtypes.int32, next_element.dtype)
  self.assertEqual((4,), next_element.shape)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)

def testPrefetchDictToDevice(self):
  host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
  device_dataset = host_dataset.apply(
      prefetching_ops.prefetch_to_device("/cpu:1"))

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_one_shot_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(host_dataset),
          dataset_ops.get_structure(device_dataset)))

  self.assertEqual(dtypes.int64, next_element["a"].dtype)
  self.assertEqual([], next_element["a"].shape)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    for i in range(10):
      self.assertEqual({"a": i}, self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
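
# Illustrative sketch (not from the original test): the public transformation
# these tests exercise, `tf.data.experimental.prefetch_to_device`. Assumes a
# GPU is available; it must be the final transformation in the pipeline.
import tensorflow as tf

ds = tf.data.Dataset.range(10).map(lambda x: {"a": x})
ds = ds.apply(tf.data.experimental.prefetch_to_device("/gpu:0", buffer_size=1))
for elem in ds:
  print(elem["a"].numpy())  # elements arrive already staged on the GPU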

def testSparseTensorIteratorStructure(self, expected_element_structure,
                                      expected_output_classes,
                                      expected_output_types,
                                      expected_output_shapes):

  def tf_value_fn():
    return sparse_tensor.SparseTensor(
        indices=[[0]],
        values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[1])

  tf_value = tf_value_fn()
  iterator = dataset_ops.make_one_shot_iterator(
      dataset_ops.Dataset.from_tensors(tf_value))

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(iterator), expected_element_structure))
  self.assertEqual(expected_output_classes,
                   dataset_ops.get_legacy_output_classes(iterator))
  self.assertEqual(expected_output_types,
                   dataset_ops.get_legacy_output_types(iterator))
  self.assertEqual(expected_output_shapes,
                   dataset_ops.get_legacy_output_shapes(iterator))

def testPrefetchToSameDevice(self):
  host_dataset = dataset_ops.Dataset.range(10)
  device_dataset = host_dataset.apply(
      prefetching_ops.prefetch_to_device(
          "/job:localhost/replica:0/task:0/device:CPU:0"))

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_one_shot_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(host_dataset),
          dataset_ops.get_structure(device_dataset)))

  self.assertEqual(dtypes.int64, next_element.dtype)
  self.assertEqual([], next_element.shape)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    for i in range(10):
      self.assertEqual(i, self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)

def testNestedTensorIteratorStructure(self, expected_element_structure,
                                      expected_output_classes,
                                      expected_output_types,
                                      expected_output_shapes):

  def tf_value_fn():
    return {
        "a": constant_op.constant(37.0),
        "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
    }

  tf_value = tf_value_fn()
  iterator = dataset_ops.make_one_shot_iterator(
      dataset_ops.Dataset.from_tensors(tf_value))

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(iterator), expected_element_structure))
  self.assertEqual(expected_output_classes,
                   dataset_ops.get_legacy_output_classes(iterator))
  self.assertEqual(expected_output_types,
                   dataset_ops.get_legacy_output_types(iterator))
  self.assertEqual(expected_output_shapes,
                   dataset_ops.get_legacy_output_shapes(iterator))

def __init__(self, input_dataset, features, num_parallel_calls):
  self._input_dataset = input_dataset
  if not structure.are_compatible(
      input_dataset.element_spec,
      tensor_spec.TensorSpec([None], dtypes.string)):
    raise TypeError("Input dataset should be a dataset of vectors of strings")
  self._num_parallel_calls = num_parallel_calls
  # pylint: disable=protected-access
  self._features = parsing_ops._prepend_none_dimension(features)
  # sparse_keys and dense_keys come back sorted here.
  (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
   dense_shapes) = parsing_ops._features_to_raw_params(
       self._features, [
           parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
           parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
       ])
  # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
  (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
   dense_shape_as_shape) = parsing_ops._process_raw_parameters(
       None, dense_defaults, sparse_keys, sparse_types, dense_keys,
       dense_types, dense_shapes)
  # pylint: enable=protected-access
  self._sparse_keys = sparse_keys
  self._sparse_types = sparse_types
  self._dense_keys = dense_keys
  self._dense_defaults = dense_defaults_vec
  self._dense_shapes = dense_shapes
  self._dense_types = dense_types
  input_dataset_shape = dataset_ops.get_legacy_output_shapes(
      self._input_dataset)
  dense_output_shapes = [
      input_dataset_shape.concatenate(shape)
      for shape in dense_shape_as_shape
  ]
  sparse_output_shapes = [
      input_dataset_shape.concatenate([None])
      for _ in range(len(sparse_keys))
  ]

  output_shapes = dict(
      zip(self._dense_keys + self._sparse_keys,
          dense_output_shapes + sparse_output_shapes))
  output_types = dict(
      zip(self._dense_keys + self._sparse_keys,
          self._dense_types + self._sparse_types))
  output_classes = dict(
      zip(self._dense_keys + self._sparse_keys,
          [ops.Tensor for _ in range(len(self._dense_defaults))] +
          [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))]))
  self._element_spec = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)

  if compat.forward_compatible(2019, 8, 3):
    variant_tensor = (
        gen_experimental_dataset_ops.parse_example_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._num_parallel_calls,
            self._dense_defaults,
            self._sparse_keys,
            self._dense_keys,
            self._sparse_types,
            self._dense_shapes,
            **self._flat_structure))
  else:
    variant_tensor = (
        gen_experimental_dataset_ops.experimental_parse_example_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._num_parallel_calls,
            self._dense_defaults,
            self._sparse_keys,
            self._dense_keys,
            self._sparse_types,
            self._dense_shapes,
            **self._flat_structure))
  super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)

def __init__(self, input_dataset, features, num_parallel_calls,
             deterministic):
  self._input_dataset = input_dataset
  if not structure.are_compatible(
      input_dataset.element_spec,
      tensor_spec.TensorSpec([None], dtypes.string)):
    raise TypeError(
        "Input dataset should be a dataset of vectors of strings")
  self._num_parallel_calls = num_parallel_calls
  if deterministic is None:
    self._deterministic = "default"
  elif deterministic:
    self._deterministic = "true"
  else:
    self._deterministic = "false"
  # pylint: disable=protected-access
  self._features = parsing_ops._prepend_none_dimension(features)
  # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature
  params = parsing_ops._ParseOpParams.from_features(self._features, [
      parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
      parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature,
      parsing_ops.RaggedFeature
  ])
  # pylint: enable=protected-access
  self._sparse_keys = params.sparse_keys
  self._sparse_types = params.sparse_types
  self._ragged_keys = params.ragged_keys
  self._ragged_value_types = params.ragged_value_types
  self._ragged_split_types = params.ragged_split_types
  self._dense_keys = params.dense_keys
  self._dense_defaults = params.dense_defaults_vec
  self._dense_shapes = params.dense_shapes_as_proto
  self._dense_types = params.dense_types
  input_dataset_shape = dataset_ops.get_legacy_output_shapes(
      self._input_dataset)

  self._element_spec = {}

  for (key, value_type) in zip(params.sparse_keys, params.sparse_types):
    self._element_spec[key] = sparse_tensor.SparseTensorSpec(
        input_dataset_shape.concatenate([None]), value_type)

  for (key, value_type, dense_shape) in zip(params.dense_keys,
                                            params.dense_types,
                                            params.dense_shapes):
    self._element_spec[key] = tensor_spec.TensorSpec(
        input_dataset_shape.concatenate(dense_shape), value_type)

  for (key, value_type, splits_type) in zip(params.ragged_keys,
                                            params.ragged_value_types,
                                            params.ragged_split_types):
    self._element_spec[key] = ragged_tensor.RaggedTensorSpec(
        input_dataset_shape.concatenate([None]), value_type, 1, splits_type)

  variant_tensor = (
      gen_experimental_dataset_ops.parse_example_dataset_v2(
          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
          self._num_parallel_calls,
          self._dense_defaults,
          self._sparse_keys,
          self._dense_keys,
          self._sparse_types,
          self._dense_shapes,
          deterministic=self._deterministic,
          ragged_keys=self._ragged_keys,
          ragged_value_types=self._ragged_value_types,
          ragged_split_types=self._ragged_split_types,
          **self._flat_structure))
  super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
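
# Illustrative sketch (not from the original source): the class above backs
# `tf.data.experimental.parse_example_dataset`. Feature names and values here
# are made up for the example.
import tensorflow as tf

serialized = tf.train.Example(features=tf.train.Features(feature={
    "x": tf.train.Feature(float_list=tf.train.FloatList(value=[1.0])),
})).SerializeToString()

ds = tf.data.Dataset.from_tensors([serialized])  # a vector of serialized protos
ds = ds.apply(tf.data.experimental.parse_example_dataset(
    {"x": tf.io.FixedLenFeature([1], tf.float32)}))
for parsed in ds:
  print(parsed["x"].numpy())  # [[1.]]; leading dim comes from the input vector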