def assertDatasetsEqual(self, dataset1, dataset2):
  """Checks that datasets are equal. Supports both graph and eager mode."""
  self.assertTrue(dataset_ops.get_structure(dataset1).is_compatible_with(
      dataset_ops.get_structure(dataset2)))
  self.assertTrue(dataset_ops.get_structure(dataset2).is_compatible_with(
      dataset_ops.get_structure(dataset1)))

  flattened_types = nest.flatten(
      dataset_ops.get_legacy_output_types(dataset1))

  next1 = self.getNext(dataset1)
  next2 = self.getNext(dataset2)

  while True:
    try:
      op1 = self.evaluate(next1())
    except errors.OutOfRangeError:
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next2())
      break
    op2 = self.evaluate(next2())

    op1 = nest.flatten(op1)
    op2 = nest.flatten(op2)
    assert len(op1) == len(op2)
    for i in range(len(op1)):
      if sparse_tensor.is_sparse(op1[i]):
        self.assertSparseValuesEqual(op1[i], op2[i])
      elif flattened_types[i] == dtypes.string:
        self.assertAllEqual(op1[i], op2[i])
      else:
        self.assertAllClose(op1[i], op2[i])
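# A minimal usage sketch for the helper above, assuming it is mixed into a
# test case that also provides `getNext` (e.g. a tf.data DatasetTestBase).
# The test name is hypothetical; both datasets yield int64 scalars, so their
# structures are compatible.
def testMapMatchesRange(self):
  ds1 = dataset_ops.Dataset.range(4).map(lambda x: x * 2)
  ds2 = dataset_ops.Dataset.range(0, 8, 2)
  self.assertDatasetsEqual(ds1, ds2)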
def testCopySparseTensorsToDeviceWithPrefetch(self):

  def make_tensor(i):
    return sparse_tensor.SparseTensorValue(
        indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])

  host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)

  device_dataset = host_dataset.apply(
      prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_one_shot_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(device_dataset)))
  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(iterator)))

  self.assertEqual(dtypes.int64, next_element.dtype)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    for i in range(10):
      actual = self.evaluate(next_element)
      self.assertAllEqual([i], actual.values)
      self.assertAllEqual([[0, 0]], actual.indices)
      self.assertAllEqual([2, 2], actual.dense_shape)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
def testCopyToDeviceWithReInitAndPrefetch(self):
  host_dataset = dataset_ops.Dataset.range(10)
  device_dataset = host_dataset.apply(
      prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_initializable_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(device_dataset)))
  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(iterator)))

  self.assertEqual(dtypes.int64, next_element.dtype)
  self.assertEqual([], next_element.shape)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    self.evaluate(iterator.initializer)
    for i in range(5):
      self.assertEqual(i, self.evaluate(next_element))
    self.evaluate(iterator.initializer)
    for i in range(10):
      self.assertEqual(i, self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
def testIteratorStringHandle(self):
  dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
  dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

  iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
  iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

  handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
  feedable_iterator = iterator_ops.Iterator.from_string_handle(
      handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
      dataset_ops.get_legacy_output_shapes(dataset_3))
  next_element = feedable_iterator.get_next()

  self.assertTrue(dataset_ops.get_structure(dataset_3).is_compatible_with(
      dataset_ops.get_structure(feedable_iterator)))
  self.assertTrue(dataset_ops.get_structure(dataset_4).is_compatible_with(
      dataset_ops.get_structure(feedable_iterator)))

  with self.cached_session() as sess:
    iterator_3_handle = sess.run(iterator_3.string_handle())
    iterator_4_handle = sess.run(iterator_4.string_handle())

    self.assertEqual(
        10,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle}))
    self.assertEqual(
        1,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_3_handle}))
    self.assertEqual(
        20,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle}))
    self.assertEqual(
        2,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_3_handle}))
    self.assertEqual(
        30,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle}))
    self.assertEqual(
        3,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_3_handle}))
    self.assertEqual(
        40,
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle}))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element,
               feed_dict={handle_placeholder: iterator_3_handle})
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(next_element,
               feed_dict={handle_placeholder: iterator_4_handle})
def __init__(self, input_dataset, window_size, window_shift, window_stride):
  """See `sliding_window_batch` for details."""
  self._input_dataset = input_dataset
  # Note: the original code named this tensor "window_stride"; the name is
  # corrected here to match the argument.
  self._window_size = ops.convert_to_tensor(
      window_size, dtype=dtypes.int64, name="window_size")
  self._window_stride = ops.convert_to_tensor(
      window_stride, dtype=dtypes.int64, name="window_stride")
  self._window_shift = ops.convert_to_tensor(
      window_shift, dtype=dtypes.int64, name="window_shift")

  input_structure = dataset_ops.get_structure(input_dataset)
  self._element_spec = nest.map_structure(
      lambda component_spec: component_spec._batch(None),
      input_structure)  # pylint: disable=protected-access
  variant_tensor = ged_ops.experimental_sliding_window_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      window_size=self._window_size,
      window_shift=self._window_shift,
      window_stride=self._window_stride,
      **self._flat_structure)
  super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
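# Hedged usage sketch for the class above via its public wrapper; the module
# alias `sliding` (tensorflow.python.data.experimental.ops.sliding) is an
# assumption. With window_size=3 and window_shift=1, range(5) yields the
# windows [0 1 2], [1 2 3], [2 3 4].
ds = dataset_ops.Dataset.range(5).apply(
    sliding.sliding_window_batch(window_size=3, window_shift=1))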
def __init__(self, input_dataset, batch_sizes, drop_remainder=False):
  """Creates a _RebatchDataset.

  Args:
    input_dataset: `Dataset` to rebatch.
    batch_sizes: A `tf.int64` scalar or vector, representing the size of
      batches to produce. If this argument is a vector, these values are
      cycled through in order.
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in the case it has fewer than
      `batch_sizes[cycle_index]` elements; the default behavior is not to
      drop the smaller batch.
  """
  self._input_dataset = input_dataset
  self._batch_sizes = ops.convert_to_tensor(
      batch_sizes, dtype=dtypes.int64, name="batch_sizes")
  self._drop_remainder = ops.convert_to_tensor(
      drop_remainder, dtype=dtypes.bool, name="drop_remainder")
  new_batch_dim = self._compute_static_batch_dim()

  # pylint: disable=protected-access
  self._element_spec = nest.map_structure(
      lambda ts: ts._unbatch()._batch(new_batch_dim),
      dataset_ops.get_structure(input_dataset))
  # pylint: enable=protected-access

  input_dataset = dataset_ops.normalize_to_dense(input_dataset)
  variant_tensor = ged_ops.rebatch_dataset_v2(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      batch_sizes=batch_sizes,
      drop_remainder=drop_remainder,
      **self._flat_structure)
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
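# Illustrative sketch of the cycling behavior described in the docstring,
# constructing the private class directly (an assumption; users normally go
# through a public rebatching API). Sixteen elements batched by 8 and then
# rebatched with batch_sizes=[3, 5] come out as batches of 3, 5, 3, 5.
ds = dataset_ops.Dataset.range(16).batch(8, drop_remainder=True)
ds = _RebatchDataset(ds, batch_sizes=[3, 5])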
def _may_form_partial_batches(self, desired_batch_size):
  """Returns whether this dataset may form partial batches."""
  if tensor_util.constant_value(self._drop_remainder):
    return False

  def get_batch_dim(type_spec):
    shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
    if not isinstance(shape, tensor_shape.TensorShape):
      return None
    if shape.rank is None:
      return None
    if len(shape) < 1:
      raise ValueError(
          "Expected a dataset whose elements have rank >= 1 "
          "but found a dataset whose elements are scalars. "
          "You can fix the issue by adding the `batch` "
          "transformation to the dataset.")
    return shape.dims[0].value

  input_batch_dims = [
      get_batch_dim(ts) for ts in nest.flatten(
          dataset_ops.get_structure(self._input_dataset))
  ]
  known_input_batch_dims = [d for d in input_batch_dims if d is not None]

  if not known_input_batch_dims:
    return True

  known_input_batch_dims = np.asarray(known_input_batch_dims)
  if not np.all(known_input_batch_dims == known_input_batch_dims[0]):
    raise ValueError(
        "Batch dimensions of input dataset are not compatible.")

  return known_input_batch_dims[0] % desired_batch_size != 0
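# Worked example of the arithmetic above, as comments:
#   - the input components have static batch dimension 8;
#   - with desired_batch_size 3, 8 % 3 != 0, so a partial batch may appear
#     (batches of 3, 3, then 2) and the method returns True;
#   - with desired_batch_size 4, 8 % 4 == 0 and the method returns False;
#   - if no component has a static batch dimension, the method conservatively
#     returns True.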
def write(self, dataset):
  """Writes a dataset to a TFRecord file.

  An operation that writes the content of the specified dataset to the file
  specified in the constructor. If the file exists, it will be overwritten.

  Args:
    dataset: a `tf.data.Dataset` whose elements are to be written to a file.

  Returns:
    In graph mode, this returns an operation which when executed performs
    the write. In eager mode, the write is performed by the method itself
    and there is no return value.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset`.
    TypeError: if the elements produced by the dataset are not scalar
      strings.
  """
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError(
        f"Invalid `dataset`. Expected a `tf.data.Dataset` object but got "
        f"{type(dataset)}.")
  if not dataset_ops.get_structure(dataset).is_compatible_with(
      tensor_spec.TensorSpec([], dtypes.string)):
    raise TypeError(
        f"Invalid `dataset`. Expected a `dataset` that produces scalar "
        f"`tf.string` elements, but got a dataset which produces elements "
        f"with shapes {dataset_ops.get_legacy_output_shapes(dataset)} and "
        f"types {dataset_ops.get_legacy_output_types(dataset)}.")
  # pylint: disable=protected-access
  dataset = dataset._apply_debug_options()
  return gen_experimental_dataset_ops.dataset_to_tf_record(
      dataset._variant_tensor, self._filename, self._compression_type)
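# Hedged usage sketch for the method above (the enclosing class is exposed as
# tf.data.experimental.TFRecordWriter); the output path is hypothetical. The
# dataset must produce scalar strings, or `write` raises TypeError.
dataset = dataset_ops.Dataset.from_tensor_slices([b"record-1", b"record-2"])
writer = TFRecordWriter("/tmp/out.tfrecord")
write_op = writer.write(dataset)  # a tf.Operation in graph mode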
def testRoundtripMap(self):
  dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x)
  variant = dataset_ops.to_variant(dataset)
  dataset = dataset_ops.from_variant(variant,
                                     dataset_ops.get_structure(dataset))
  self.assertDatasetProduces(dataset, [x * x for x in range(10)])
  self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
def testRoundtripRange(self):
  dataset = dataset_ops.Dataset.range(10)
  variant = dataset_ops.to_variant(dataset)
  dataset = dataset_ops.from_variant(variant,
                                     dataset_ops.get_structure(dataset))
  self.assertDatasetProduces(dataset, range(10))
  self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
def __init__(self, input_dataset, num_replicas, use_fallback=True):

  def recalculate_batch_size(output_shapes):
    """Recalculates the output_shapes after dividing it by num_replicas."""
    if len(output_shapes) < 1:
      raise ValueError(
          "Input shape should have at least one dimension. "
          "Perhaps your input dataset is not batched?")
    output_dims = [d.value for d in output_shapes.dims]
    if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
      return output_dims[0] // num_replicas

    # Set the batch dimension to unknown. If the global batch size does not
    # divide num_replicas evenly, the minibatches may have different sizes.
    return None

  def rebatch(type_spec):
    # pylint: disable=protected-access
    batch_size = recalculate_batch_size(type_spec._to_legacy_output_shapes())
    return type_spec._unbatch()._batch(batch_size)
    # pylint: enable=protected-access

  self._element_spec = nest.map_structure(
      rebatch, dataset_ops.get_structure(input_dataset))
  variant_tensor = ged_ops.rebatch_dataset(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_replicas=num_replicas,
      **self._flat_structure)
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
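# Sketch of the intended effect, constructing the private class directly (an
# assumption; this class normally backs a distribution-strategy rewrite). A
# global batch of 8 split across 4 replicas yields per-replica batches of 2.
ds = dataset_ops.Dataset.range(32).batch(8, drop_remainder=True)  # shape [8]
ds = _RebatchDataset(ds, num_replicas=4)  # element shape becomes [2]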
def write(self, dataset):
  """Returns a `tf.Operation` to write a dataset to a file.

  Args:
    dataset: a `tf.data.Dataset` whose elements are to be written to a file.

  Returns:
    A `tf.Operation` that, when run, writes contents of `dataset` to a file.
  """
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
  if not dataset_ops.get_structure(dataset).is_compatible_with(
      structure.TensorStructure(dtypes.string, [])):
    raise TypeError(
        "`dataset` must produce scalar `DT_STRING` tensors whereas it "
        "produces shape {0} and types {1}".format(
            dataset_ops.get_legacy_output_shapes(dataset),
            dataset_ops.get_legacy_output_types(dataset)))
  if compat.forward_compatible(2019, 8, 3):
    return gen_experimental_dataset_ops.dataset_to_tf_record(
        dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
  else:
    return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
        dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
def _testDatasetSpec(self, tf_value, expected_element_structure):
  dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value)
  dataset_structure = structure.type_spec_from_value(dataset)
  self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(dataset), expected_element_structure))
  self.assertEqual([dtypes.variant],
                   structure.get_flat_tensor_types(dataset_structure))
  self.assertEqual([tensor_shape.TensorShape([])],
                   structure.get_flat_tensor_shapes(dataset_structure))

  # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
  # and _to_tensor_list().
  round_trip_dataset = dataset_structure._from_tensor_list(
      dataset_structure._to_tensor_list(dataset))

  value = tf_value

  if isinstance(value, dataset_ops.Dataset):
    # Compare against the round-tripped dataset (the original code compared
    # against the pre-round-trip `dataset`, which would not exercise the
    # round trip).
    self.assertDatasetsEqual(value, round_trip_dataset.flat_map(lambda x: x))
  elif isinstance(value, optional_ops.Optional):
    self.assertDatasetProduces(
        round_trip_dataset.map(lambda opt: opt.get_value()),
        [self.evaluate(value.get_value())],
        requires_initialization=True)
  else:
    self.assertDatasetProduces(
        round_trip_dataset, [self.evaluate(tf_value)],
        requires_initialization=True)
def _may_form_partial_batches(self, desired_batch_size):
  """Returns whether this dataset may form partial batches."""
  if tensor_util.constant_value(self._drop_remainder):
    return False

  def get_batch_dim(type_spec):
    shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
    if not isinstance(shape, tensor_shape.TensorShape):
      return None
    if shape.rank is None:
      return None
    if len(shape) < 1:
      raise ValueError("Invalid `batch_sizes`. Expected a dataset whose "
                       "elements have rank >= 1 but found a dataset with "
                       "scalar elements. Fix the issue by adding the `batch` "
                       "transformation to the dataset.")
    return shape.dims[0].value

  input_batch_dims = [
      get_batch_dim(ts)
      for ts in nest.flatten(dataset_ops.get_structure(self._input_dataset))
  ]
  known_input_batch_dims = [d for d in input_batch_dims if d is not None]

  if not known_input_batch_dims:
    return True

  known_input_batch_dims = np.asarray(known_input_batch_dims)
  if not np.all(known_input_batch_dims == known_input_batch_dims[0]):
    raise ValueError(
        f"Invalid `input_dataset`. The batch dimension of component 0 "
        f"is {known_input_batch_dims[0]}, while the batch dimensions "
        f"of all known components are {known_input_batch_dims}.")

  return known_input_batch_dims[0] % desired_batch_size != 0
def compute_batch_size(dataset):
  """An operation that returns the batch size of the dataset.

  This op tries to infer the batch size statically by walking up the dataset
  tree from the final dataset node and returning the batch size of the first
  batching dataset (such as from .batch() and .padded_batch()) that it
  encounters. This differs from using the `element_spec` of a dataset in
  that it does not account for partial batches.

  This operation may fail if it encounters contradictory batch sizes (for
  example, if the dataset is created by zipping together two datasets with
  different batch sizes), if there are no explicit batching transformations,
  or if there are operations downstream from the batching transformation
  that may modify its batch size. In these cases, it returns -1.

  Args:
    dataset: A `tf.data.Dataset` object.

  Returns:
    A `tf.int64` Tensor representing the batch size of the dataset sans
    partial batches. If this cannot be inferred statically, the value of
    this tensor will be -1.
  """

  def get_static_batch_dim(type_spec):
    try:
      output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
    except NotImplementedError:
      return None
    if not isinstance(output_shape, tensor_shape.TensorShape):
      return None
    if output_shape.rank is None:
      return None
    return output_shape.dims[0].value

  batch_dims = [
      get_static_batch_dim(type_spec)
      for type_spec in nest.flatten(dataset_ops.get_structure(dataset))
  ]

  if all(d is not None for d in batch_dims):

    if all(d == batch_dims[0] for d in batch_dims):
      # If all batch dimensions are known and equal, return that directly.
      batch_dim = batch_dims[0]
    else:
      # If all batch dimensions are known but not all equal, return -1.
      batch_dim = -1

    return constant_op.constant(
        batch_dim, dtype=dtypes.int64, name="static_batch_size")

  # If any batch dimensions are unknown, use compute_batch_size op.
  return ged_ops.compute_batch_size(dataset._variant_tensor)  # pylint: disable=protected-access
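# Usage sketch for compute_batch_size: with drop_remainder=True the batch
# dimension is statically known, so the result is a constant tensor.
ds = dataset_ops.Dataset.range(32).batch(8, drop_remainder=True)
batch_size = compute_batch_size(ds)  # tf.int64 tensor that evaluates to 8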
def __init__(self, input_dataset, num_replicas):
  """Creates a _LegacyRebatchDataset.

  Args:
    input_dataset: `Dataset` to rebatch.
    num_replicas: A `tf.int64` scalar, representing the number of
      sub-batches to split each batch from `input_dataset` into.
  """

  def recalculate_batch_size(type_spec):
    """Recalculates the output_shape after dividing it by num_replicas."""
    output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
    if not isinstance(output_shape, tensor_shape.TensorShape):
      return None

    # If the output shape is unknown, we set the batch dimension to unknown.
    if output_shape.rank is None:
      return None

    if len(output_shape) < 1:
      raise ValueError(
          "Invalid `input_dataset`. Expected a dataset whose elements "
          "have rank >= 1 but found a dataset whose elements are scalars. "
          "Fix the issue by adding the `batch` transformation to the "
          "dataset.")
    output_dims = [d.value for d in output_shape.dims]

    if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
      return output_dims[0] // num_replicas

    # Set the batch dimension to unknown. If the global batch size does not
    # divide num_replicas evenly, the minibatches may have different sizes.
    return None

  def rebatch(type_spec):
    # pylint: disable=protected-access
    batch_size = recalculate_batch_size(type_spec)
    return type_spec._unbatch()._batch(batch_size)
    # pylint: enable=protected-access

  self._element_spec = nest.map_structure(
      rebatch, dataset_ops.get_structure(input_dataset))

  # auto_shard rewrite assumes that there's normalize_to_dense before
  # rebatch_dataset.
  # LINT.IfChange
  input_dataset = dataset_ops.normalize_to_dense(input_dataset)
  variant_tensor = ged_ops.rebatch_dataset(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_replicas=num_replicas,
      **self._flat_structure)
  # LINT.ThenChange(//tensorflow/core/grappler/optimizers/data/auto_shard.cc)
  super(_LegacyRebatchDataset, self).__init__(input_dataset, variant_tensor)
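# Complementary sketch for the unknown-batch-dimension branch above: without
# drop_remainder the batch dimension is statically None, so
# recalculate_batch_size returns None and the rebatched element spec keeps an
# unknown leading dimension. Constructing the private class directly is an
# assumption for illustration.
ds = dataset_ops.Dataset.range(10).batch(4)     # element shape [None]
ds = _LegacyRebatchDataset(ds, num_replicas=2)  # shape stays [None]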
def testCopyToDeviceInt32(self):
  host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
  device_dataset = host_dataset.apply(
      prefetching_ops.copy_to_device("/cpu:1"))

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_one_shot_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(device_dataset)))
  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(iterator)))

  self.assertEqual(dtypes.int32, next_element.dtype)
  self.assertEqual((4,), next_element.shape)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
def testCopyToDeviceInt32(self):
  host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
  device_dataset = host_dataset.apply(
      prefetching_ops.copy_to_device("/cpu:1"))

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_one_shot_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(host_dataset),
          dataset_ops.get_structure(device_dataset)))

  self.assertEqual(dtypes.int32, next_element.dtype)
  self.assertEqual((4,), next_element.shape)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
def testPrefetchDictToDevice(self):
  host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
  device_dataset = host_dataset.apply(
      prefetching_ops.prefetch_to_device("/cpu:1"))

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_one_shot_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(host_dataset),
          dataset_ops.get_structure(device_dataset)))

  self.assertEqual(dtypes.int64, next_element["a"].dtype)
  self.assertEqual([], next_element["a"].shape)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    for i in range(10):
      self.assertEqual({"a": i}, self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
def testCopyDictToDeviceWithPrefetch(self):
  host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
  device_dataset = host_dataset.apply(
      prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_one_shot_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(device_dataset)))
  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(iterator)))

  self.assertEqual(dtypes.int64, next_element["a"].dtype)
  self.assertEqual([], next_element["a"].shape)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    for i in range(10):
      self.assertEqual({"a": i}, self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
def testPrefetchToSameDevice(self):
  host_dataset = dataset_ops.Dataset.range(10)
  device_dataset = host_dataset.apply(
      prefetching_ops.prefetch_to_device(
          "/job:localhost/replica:0/task:0/device:CPU:0"))

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_one_shot_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(host_dataset),
          dataset_ops.get_structure(device_dataset)))

  self.assertEqual(dtypes.int64, next_element.dtype)
  self.assertEqual([], next_element.shape)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    for i in range(10):
      self.assertEqual(i, self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
def testCopyToDevice(self):
  host_dataset = dataset_ops.Dataset.range(10)
  device_dataset = host_dataset.apply(
      prefetching_ops.copy_to_device("/cpu:1"))

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_one_shot_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(device_dataset)))
  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(iterator)))

  self.assertEqual(dtypes.int64, next_element.dtype)
  self.assertEqual([], next_element.shape)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    for i in range(10):
      self.assertEqual(i, self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
def testPrefetchToSameDevice(self):
  host_dataset = dataset_ops.Dataset.range(10)
  device_dataset = host_dataset.apply(
      prefetching_ops.prefetch_to_device(
          "/job:localhost/replica:0/task:0/device:CPU:0"))

  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_one_shot_iterator(device_dataset)
    next_element = iterator.get_next()

  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(device_dataset)))
  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(iterator)))

  self.assertEqual(dtypes.int64, next_element.dtype)
  self.assertEqual([], next_element.shape)

  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  with self.test_session(config=worker_config):
    for i in range(10):
      self.assertEqual(i, self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
def testIteratorStructure(self, tf_value_fn, expected_element_structure,
                          expected_output_classes, expected_output_types,
                          expected_output_shapes):
  tf_value = tf_value_fn()
  iterator = dataset_ops.make_one_shot_iterator(
      dataset_ops.Dataset.from_tensors(tf_value))

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(iterator), expected_element_structure))
  self.assertEqual(expected_output_classes,
                   dataset_ops.get_legacy_output_classes(iterator))
  self.assertEqual(expected_output_types,
                   dataset_ops.get_legacy_output_types(iterator))
  self.assertEqual(expected_output_shapes,
                   dataset_ops.get_legacy_output_shapes(iterator))
def dataset_extractor(dataset,
                      num_elements,
                      filename,
                      feed_name,
                      print_stats=True):
  """Allows the user to extract tensors from a `tf.data.Dataset`.

  Args:
    dataset: An instance of `tf.data.Dataset` to extract elements from.
    num_elements: The number of elements to extract from the dataset.
    filename: Where to save the extracted elements to.
    feed_name: Name of the infeed the dataset is associated with.
    print_stats: Whether to print progress messages to the console.

  Note: All the tuple elements will be saved in the same binary file.

  Returns:
    The operation that will save the elements of the infeed to file.

  Raises:
    TypeError: if `dataset` is not an instance of `tf.data.Dataset`.
    ValueError: if `num_elements` is less than 1.
  """
  # Note: the original code used `return` here; these exceptions must be
  # raised, as documented above.
  if num_elements < 1:
    raise ValueError("Expected `num_elements` to be at least 1.")

  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError("Expected `dataset` argument to be of type "
                    "`tf.data.Dataset`, but got %s "
                    "instead." % (str(dataset)))

  try:
    dataset_variant = dataset._variant_tensor  # pylint: disable=protected-access
  except TypeError:
    dataset_variant = dataset._as_variant_tensor  # pylint: disable=protected-access

  struct = dataset_ops.get_structure(dataset)
  return gen_dataset_exporters.dataset_extractor(
      dataset_variant, print_stats, num_elements, filename, feed_name,
      **dataset._flat_structure)  # pylint: disable=protected-access
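# Heavily hedged usage sketch for the IPU helper above; the file path and
# feed name are hypothetical. In graph mode the returned op must be run in a
# session to perform the extraction.
ds = dataset_ops.Dataset.range(100).batch(10, drop_remainder=True)
save_op = dataset_extractor(
    ds, num_elements=5, filename="/tmp/infeed.bin", feed_name="training_feed")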
def __init__(self, input_dataset, window_size, window_shift, window_stride):
  """See `sliding_window_batch` for details."""
  self._input_dataset = input_dataset
  # Note: the original code named this tensor "window_stride"; the name is
  # corrected here to match the argument.
  self._window_size = ops.convert_to_tensor(
      window_size, dtype=dtypes.int64, name="window_size")
  self._window_stride = ops.convert_to_tensor(
      window_stride, dtype=dtypes.int64, name="window_stride")
  self._window_shift = ops.convert_to_tensor(
      window_shift, dtype=dtypes.int64, name="window_shift")

  input_structure = dataset_ops.get_structure(input_dataset)
  self._structure = input_structure._batch(None)  # pylint: disable=protected-access
  variant_tensor = ged_ops.experimental_sliding_window_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      window_size=self._window_size,
      window_shift=self._window_shift,
      window_stride=self._window_stride,
      **dataset_ops.flat_structure(self))
  super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
def write(self, dataset):
  """Returns a `tf.Operation` to write a dataset to a file.

  Args:
    dataset: a `tf.data.Dataset` whose elements are to be written to a file.

  Returns:
    A `tf.Operation` that, when run, writes contents of `dataset` to a file.
  """
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
  if not dataset_ops.get_structure(dataset).is_compatible_with(
      structure.TensorStructure(dtypes.string, [])):
    raise TypeError(
        "`dataset` must produce scalar `DT_STRING` tensors whereas it "
        "produces shape {0} and types {1}".format(
            dataset_ops.get_legacy_output_shapes(dataset),
            dataset_ops.get_legacy_output_types(dataset)))
  return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
      dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
def __init__(self, input_dataset, num_replicas, use_fallback=True):

  def recalculate_batch_size(type_spec):
    """Recalculates the output_shape after dividing it by num_replicas."""
    output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
    if not isinstance(output_shape, tensor_shape.TensorShape):
      return None

    # If the output shape is unknown, we set the batch dimension to unknown.
    if output_shape.rank is None:
      return None

    if len(output_shape) < 1:
      raise ValueError(
          "Expected a dataset whose elements have rank >= 1 "
          "but found a dataset whose elements are scalars. "
          "You can fix the issue by adding the `batch` "
          "transformation to the dataset.")
    output_dims = [d.value for d in output_shape.dims]

    if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
      return output_dims[0] // num_replicas

    # Set the batch dimension to unknown. If the global batch size does not
    # divide num_replicas evenly, the minibatches may have different sizes.
    return None

  def rebatch(type_spec):
    # pylint: disable=protected-access
    batch_size = recalculate_batch_size(type_spec)
    return type_spec._unbatch()._batch(batch_size)
    # pylint: enable=protected-access

  self._element_spec = nest.map_structure(
      rebatch, dataset_ops.get_structure(input_dataset))

  input_dataset = dataset_ops.normalize_to_dense(input_dataset)
  variant_tensor = ged_ops.rebatch_dataset(
      input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_replicas=num_replicas,
      **self._flat_structure)
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
def __init__(self, input_dataset):
  """See `unbatch()` for more details."""
  input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
  flat_shapes = nest.flatten(input_shapes)
  if any(s.ndims == 0 for s in flat_shapes):
    raise ValueError("Cannot unbatch an input with scalar components.")
  known_batch_dim = tensor_shape.Dimension(None)
  for s in flat_shapes:
    try:
      known_batch_dim = known_batch_dim.merge_with(s[0])
    except ValueError:
      raise ValueError("Cannot unbatch an input whose components have "
                       "different batch sizes.")
  self._input_dataset = input_dataset
  self._structure = dataset_ops.get_structure(input_dataset)._unbatch()  # pylint: disable=protected-access
  variant_tensor = ged_ops.experimental_unbatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      **dataset_ops.flat_structure(self))
  super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
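# Usage sketch via the public wrapper of the era
# (tf.data.experimental.unbatch); the module alias `batching` is an
# assumption. Each [3]-shaped element becomes three scalar elements.
ds = dataset_ops.Dataset.from_tensors([1, 2, 3])
ds = ds.apply(batching.unbatch())  # yields the scalars 1, 2, 3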
def choose_from_datasets_v2(datasets, choice_dataset):
  """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A list of `tf.data.Dataset` objects with compatible structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
      `0` and `len(datasets) - 1`.

  Returns:
    A dataset that interleaves elements from `datasets` according to the
    values of `choice_dataset`.

  Raises:
    TypeError: If the `datasets` or `choice_dataset` arguments have the
      wrong type.
  """
  if not dataset_ops.get_structure(choice_dataset).is_compatible_with(
      structure.TensorStructure(dtypes.int64, [])):
    raise TypeError("`choice_dataset` must be a dataset of scalar "
                    "`tf.int64` tensors.")
  return _DirectedInterleaveDataset(choice_dataset, datasets)
def testSparseTensorIteratorStructure(self, expected_element_structure,
                                      expected_output_classes,
                                      expected_output_types,
                                      expected_output_shapes):

  def tf_value_fn():
    return sparse_tensor.SparseTensor(
        indices=[[0]],
        values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[1])

  tf_value = tf_value_fn()
  iterator = dataset_ops.make_one_shot_iterator(
      dataset_ops.Dataset.from_tensors(tf_value))

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(iterator), expected_element_structure))
  self.assertEqual(expected_output_classes,
                   dataset_ops.get_legacy_output_classes(iterator))
  self.assertEqual(expected_output_types,
                   dataset_ops.get_legacy_output_types(iterator))
  self.assertEqual(expected_output_shapes,
                   dataset_ops.get_legacy_output_shapes(iterator))
def testNestedTensorIteratorStructure(self, expected_element_structure,
                                      expected_output_classes,
                                      expected_output_types,
                                      expected_output_shapes):

  def tf_value_fn():
    return {
        "a": constant_op.constant(37.0),
        "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
    }

  tf_value = tf_value_fn()
  iterator = dataset_ops.make_one_shot_iterator(
      dataset_ops.Dataset.from_tensors(tf_value))

  self.assertTrue(
      structure.are_compatible(
          dataset_ops.get_structure(iterator), expected_element_structure))
  self.assertEqual(expected_output_classes,
                   dataset_ops.get_legacy_output_classes(iterator))
  self.assertEqual(expected_output_types,
                   dataset_ops.get_legacy_output_types(iterator))
  self.assertEqual(expected_output_shapes,
                   dataset_ops.get_legacy_output_shapes(iterator))
def testIteratorStringHandleFuture(self):
  with forward_compat.forward_compatibility_horizon(2018, 8, 4):
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
        dataset_ops.get_legacy_output_shapes(dataset_3))
    next_element = feedable_iterator.get_next()

    self.assertTrue(
        structure.are_compatible(
            dataset_ops.get_structure(dataset_3),
            dataset_ops.get_structure(feedable_iterator)))

    with self.cached_session() as sess:
      iterator_3_handle = sess.run(iterator_3.string_handle())
      iterator_4_handle = sess.run(iterator_4.string_handle())

      self.assertEqual(
          10,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          1,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          20,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          2,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          30,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(
          3,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(
          40,
          sess.run(next_element,
                   feed_dict={handle_placeholder: iterator_4_handle}))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_3_handle})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element,
                 feed_dict={handle_placeholder: iterator_4_handle})
def _element_structure(self):
  return dataset_ops.get_structure(self._dataset)
def __init__(self,
             dataset,
             feed_name,
             device_ordinal=0,
             replication_factor=1,
             data_to_prefetch=1,
             prefetch_depth=None):
  """Creates an IPUInfeedQueue object.

  Args:
    dataset: a `tf.data.Dataset` object. All transformations, e.g. `shuffle`,
      `repeat` and `batch`, must be applied before the dataset is passed to
      this function. The dataset can no longer be used after creating this
      queue.
    feed_name: the name of the infeed queue. This must be unique between all
      IPUInfeedQueues and IPUOutfeedQueues.
    device_ordinal: ordinal of the IPU device on which this queue will be
      used. By default the queue will be used on "/device/IPU:0".
    replication_factor: the number of replicated graphs this infeed will be
      used in.
    data_to_prefetch: the amount of data to prefetch. Defaults to 1, meaning
      no prefetch. If set to a value other than 0 or 1, each sync with the
      CPU returns this number of dataset values rather than 1. This must not
      exceed the size of the dataset if it is not repeating, and the infeed
      is incremented by this amount on each sync. When using the infeed in
      multiple programs or loops, note that if `data_to_prefetch` is not a
      factor of the previous iteration count, the next loop or program will
      not start at the iteration it otherwise would. Prefetching increases
      memory usage, since more batches are live at a given point, but should
      give a speed-up by requiring fewer round trips to host memory. A
      larger number of batches may need to be prefetched at once to see any
      benefit, as the lookup itself has some overhead from internal copies.
    prefetch_depth: the number of elements Poplar will prefetch: the depth
      of the Poplar datastream buffer which may be filled before being read
      by the device. By default the prefetch_depth size is automatically
      determined. Increasing it allows for prefetching of multiple entries,
      increasing the probability there will be a valid entry in the buffer
      for the device to read before falling back to synchronously fetching
      the next entry.

  Raises:
    ValueError: if any dimension of the shapes in `dataset.output_shapes` is
      not fully defined. The `tf.data.Dataset.batch` function must be called
      with `drop_remainder=True` to ensure that the batch size is constant.
  """
  for output_shape in dataset._flat_structure["output_shapes"]:
    if isinstance(output_shape, (list, tuple)):
      raise ValueError("Nested list/tuple input shapes are not supported")
    if not output_shape.is_fully_defined():
      raise ValueError("""Output shape {} is not fully defined. If using \
tf.Dataset.batch, set `drop_remainder=True`.""".format(output_shape))
  if prefetch_depth is None:
    prefetch_depth = 1
  if prefetch_depth <= 0:
    raise ValueError(
        "prefetch_depth must be greater than zero, but it is {}".format(
            prefetch_depth))
  if prefetch_depth > 255:
    raise ValueError(
        "prefetch_depth must be less than 256, but it is {}".format(
            prefetch_depth))

  with ops.device('/device:CPU:0'):
    self._replication_factor = replication_factor
    self._dataset = dataset
    self._structure = dataset_ops.get_structure(self._dataset)
    self._flat_structure = dataset._flat_structure
    self._device_ordinal = device_ordinal
    self._prefetch_depth = prefetch_depth

    # We use max to clamp 0/1 to the same value.
    self._io_batch_size = max(1, data_to_prefetch)

    # Batch the dataset to take replication and prefetch into account.
    if self._io_batch_size != 1:
      self._dataset = self._dataset.batch(self._io_batch_size,
                                          drop_remainder=True)

    if self._replication_factor != 1:
      self._dataset = self._dataset.batch(self._replication_factor,
                                          drop_remainder=True)

    # Apply the dataset options and take ownership.
    self._dataset = self._dataset._apply_options()

    # ID used for differentiating between datasets.
    self._id = str(feed_name)

    try:
      ds_variant = self._dataset._variant_tensor  # pylint: disable=protected-access
    except TypeError:
      ds_variant = self._dataset._as_variant_tensor  # pylint: disable=protected-access

    if not context.executing_eagerly():
      # For Estimators, the graph can be frozen before the estimator calls
      # the initializer or deleter methods. So we need to create the
      # initialize and delete operations early. For eager execution in
      # TF2, the operations execute eagerly, so they don't exist in any
      # frozen graph.
      with ops.colocate_with(ds_variant):
        self._init_op = gen_pop_datastream_ops.ipu_create_dataset_iterator(
            input_dataset=ds_variant,
            feed_id=self._id,
            replication_factor=self._replication_factor,
            device_ordinal=self._device_ordinal,
            **self._dataset._flat_structure)  # pylint: disable=protected-access

      self._deleter = gen_pop_datastream_ops.ipu_delete_dataset_iterator(
          feed_id=self._id, device_ordinal=self._device_ordinal)

    self._dequeued = False
    self._initialized = False
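# Heavily hedged usage sketch for the queue above; the feed name is
# hypothetical, and the existence of an `initializer` property (backed by
# `_init_op`) is an assumption based on the surrounding code. The dataset
# must have fully defined shapes, hence drop_remainder=True.
ds = dataset_ops.Dataset.range(64).batch(16, drop_remainder=True)
infeed = IPUInfeedQueue(ds, feed_name="train_feed")
# ... build a training loop that dequeues from `infeed`, then run
# sess.run(infeed.initializer) before executing the loop.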
def __init__(self,
             input_dataset,
             functions,
             ratio_numerator=1,
             ratio_denominator=1,
             num_elements_per_branch=None):
  """Chooses the fastest of some dataset functions.

  Given dataset functions that take input_dataset as input and output
  another dataset, produces elements as quickly as the fastest of these
  output datasets. Note that datasets in the dataset functions are assumed
  to be stateless, and the iterators created by the functions' output
  datasets will, given the same input elements, all produce the same output
  elements. Datasets in the functions are also expected to iterate over the
  input dataset at most once. The violation of these conditions may lead to
  undefined behavior.

  For example:
  ```python
  dataset = tf.data.Dataset.range(100)
  dataset = _ChooseFastestDataset(
      dataset,
      [
          lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10),
          lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1]))
      ],
      ratio=10,
      num_elements_per_branch=10
  )
  ```
  The resulting dataset will produce elements equivalent to
  `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`,
  or
  `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))`

  Note that the first `num_elements_per_branch` iterations may be slower
  due to the overhead of dynamically picking the fastest dataset. Namely,
  for these iterations, the dataset will produce elements from any of the
  branches to determine which input is the fastest. For all subsequent
  iterations, that input will be used.

  Args:
    input_dataset: A `Dataset` that can be used as input to `functions`.
    functions: A list of callables, each of which takes a `Dataset` as input
      and returns a `Dataset`.
    ratio_numerator: The numerator in the ratio of input elements consumed
      to output elements produced for each function. This should be the
      same for all functions. For example, if the function is
      `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
      must produce 10 elements for every element of the output dataset. In
      this case, ratio_numerator should be 10.
    ratio_denominator: The denominator in the ratio of input elements
      consumed to output elements produced for each function. This should
      be the same for all functions. For example, if the function is
      `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
      must produce 10 elements for every element of the output dataset. In
      this case, ratio_denominator should be 1.
    num_elements_per_branch: The number of elements to get from each branch
      before deciding which dataset is fastest. In the first
      len(functions) * num_elements_per_branch iterations, the dataset will
      call from one of the branches, and update its knowledge of which
      input is the fastest. Note that (num_elements_per_branch * ratio) is
      expected to be an integer.

  Returns:
    A `Dataset` that has the same elements as the inputs.
  """
  input_structure = dataset_ops.DatasetStructure(
      dataset_ops.get_structure(input_dataset))
  self._funcs = [
      dataset_ops.StructuredFunctionWrapper(
          f, "ChooseFastestV2", input_structure=input_structure)
      for f in functions
  ]
  self._structure = self._funcs[0].output_structure._element_structure  # pylint: disable=protected-access

  self._captured_arguments = []
  for f in self._funcs:
    self._captured_arguments.extend(f.function.captured_inputs)
  self._capture_lengths = [
      len(f.function.captured_inputs) for f in self._funcs
  ]

  if ratio_numerator <= 0 or ratio_denominator <= 0:
    raise ValueError("ratio must be positive.")

  if num_elements_per_branch is None:
    # Pick a sensible default based on `ratio_denominator`.
    num_elements_per_branch = 10 * ratio_denominator

  variant_tensor = (
      gen_experimental_dataset_ops.choose_fastest_branch_dataset(
          input_dataset._variant_tensor,  # pylint: disable=protected-access
          ratio_numerator=ratio_numerator,
          ratio_denominator=ratio_denominator,
          other_arguments=self._captured_arguments,
          num_elements_per_branch=num_elements_per_branch,
          branches=[f.function for f in self._funcs],
          other_arguments_lengths=self._capture_lengths,
          **self._flat_structure))
  super(_ChooseFastestBranchDataset, self).__init__(input_dataset,
                                                    variant_tensor)