Example #1
  def assertDatasetsEqual(self, dataset1, dataset2):
    """Checks that datasets are equal. Supports both graph and eager mode."""
    self.assertTrue(dataset_ops.get_structure(dataset1).is_compatible_with(
        dataset_ops.get_structure(dataset2)))
    self.assertTrue(dataset_ops.get_structure(dataset2).is_compatible_with(
        dataset_ops.get_structure(dataset1)))
    flattened_types = nest.flatten(
        dataset_ops.get_legacy_output_types(dataset1))

    next1 = self.getNext(dataset1)
    next2 = self.getNext(dataset2)

    while True:
      try:
        op1 = self.evaluate(next1())
      except errors.OutOfRangeError:
        with self.assertRaises(errors.OutOfRangeError):
          self.evaluate(next2())
        break
      op2 = self.evaluate(next2())

      op1 = nest.flatten(op1)
      op2 = nest.flatten(op2)
      assert len(op1) == len(op2)
      for i in range(len(op1)):
        if sparse_tensor.is_sparse(op1[i]):
          self.assertSparseValuesEqual(op1[i], op2[i])
        elif flattened_types[i] == dtypes.string:
          self.assertAllEqual(op1[i], op2[i])
        else:
          self.assertAllClose(op1[i], op2[i])
  def testCopySparseTensorsToDeviceWithPrefetch(self):

    def make_tensor(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])

    host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)

    device_dataset = host_dataset.apply(
        prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)

    with ops.device("/cpu:1"):
      iterator = dataset_ops.make_one_shot_iterator(device_dataset)
      next_element = iterator.get_next()

    self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
        dataset_ops.get_structure(device_dataset)))
    self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
        dataset_ops.get_structure(iterator)))

    self.assertEqual(dtypes.int64, next_element.dtype)

    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=worker_config):
      for i in range(10):
        actual = self.evaluate(next_element)
        self.assertAllEqual([i], actual.values)
        self.assertAllEqual([[0, 0]], actual.indices)
        self.assertAllEqual([2, 2], actual.dense_shape)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
  def testCopyToDeviceWithReInitAndPrefetch(self):
    host_dataset = dataset_ops.Dataset.range(10)
    device_dataset = host_dataset.apply(
        prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)

    with ops.device("/cpu:1"):
      iterator = dataset_ops.make_initializable_iterator(device_dataset)
      next_element = iterator.get_next()

    self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
        dataset_ops.get_structure(device_dataset)))
    self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
        dataset_ops.get_structure(iterator)))

    self.assertEqual(dtypes.int64, next_element.dtype)
    self.assertEqual([], next_element.shape)

    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=worker_config):
      self.evaluate(iterator.initializer)
      for i in range(5):
        self.assertEqual(i, self.evaluate(next_element))
      self.evaluate(iterator.initializer)
      for i in range(10):
        self.assertEqual(i, self.evaluate(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
Example #4
  def testIteratorStringHandle(self):
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
        dataset_ops.get_legacy_output_shapes(dataset_3))
    next_element = feedable_iterator.get_next()

    self.assertTrue(dataset_ops.get_structure(dataset_3).is_compatible_with(
        dataset_ops.get_structure(feedable_iterator)))
    self.assertTrue(dataset_ops.get_structure(dataset_4).is_compatible_with(
        dataset_ops.get_structure(feedable_iterator)))

    with self.cached_session() as sess:
      iterator_3_handle = sess.run(iterator_3.string_handle())
      iterator_4_handle = sess.run(iterator_4.string_handle())

      self.assertEqual(10,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(1,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(20,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(2,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(30,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_4_handle}))
      self.assertEqual(3,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_3_handle}))
      self.assertEqual(40,
                       sess.run(
                           next_element,
                           feed_dict={handle_placeholder: iterator_4_handle}))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(
            next_element, feed_dict={handle_placeholder: iterator_3_handle})
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(
            next_element, feed_dict={handle_placeholder: iterator_4_handle})
Example #5
    def __init__(self, input_dataset, window_size, window_shift,
                 window_stride):
        """See `sliding_window_batch` for details."""
        self._input_dataset = input_dataset
        self._window_size = ops.convert_to_tensor(window_size,
                                                  dtype=dtypes.int64,
                                                  name="window_stride")
        self._window_stride = ops.convert_to_tensor(window_stride,
                                                    dtype=dtypes.int64,
                                                    name="window_stride")
        self._window_shift = ops.convert_to_tensor(window_shift,
                                                   dtype=dtypes.int64,
                                                   name="window_shift")

        input_structure = dataset_ops.get_structure(input_dataset)
        self._element_spec = nest.map_structure(
            lambda component_spec: component_spec._batch(None),
            input_structure)  # pylint: disable=protected-access
        variant_tensor = ged_ops.experimental_sliding_window_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            window_size=self._window_size,
            window_shift=self._window_shift,
            window_stride=self._window_stride,
            **self._flat_structure)
        super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
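`_SlideDataset` is the internal op behind the sliding-window transformation. A minimal sketch of the same behaviour using only the public `tf.data.Dataset.window` API (the parameter values are illustrative, not taken from the example above):

```python
import tensorflow as tf

# Sliding windows of size 3 that advance by 1 element, comparable to
# window_size=3, window_shift=1, window_stride=1 above.
dataset = tf.data.Dataset.range(10)
windows = dataset.window(size=3, shift=1, stride=1, drop_remainder=True)
dense_windows = windows.flat_map(lambda window: window.batch(3))
# dense_windows yields [0 1 2], [1 2 3], ..., [7 8 9].
```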
Example #6
  def __init__(self, input_dataset, batch_sizes, drop_remainder=False):
    """Creates a _RebatchDataset.

    Args:
      input_dataset: `Dataset` to rebatch.
      batch_sizes: A `tf.int64` scalar or vector, representing the size of
        batches to produce. If this argument is a vector, these values are
        cycled through in order.
      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
        whether the last batch should be dropped in the case it has fewer than
        `batch_sizes[cycle_index]` elements; the default behavior is not to drop
        the smaller batch.
    """
    self._input_dataset = input_dataset
    self._batch_sizes = ops.convert_to_tensor(
        batch_sizes, dtype=dtypes.int64, name="batch_sizes")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
    new_batch_dim = self._compute_static_batch_dim()

    # pylint: disable=protected-access
    self._element_spec = nest.map_structure(
        lambda ts: ts._unbatch()._batch(new_batch_dim),
        dataset_ops.get_structure(input_dataset))
    # pylint: enable=protected-access

    input_dataset = dataset_ops.normalize_to_dense(input_dataset)
    variant_tensor = ged_ops.rebatch_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        batch_sizes=batch_sizes,
        drop_remainder=drop_remainder,
        **self._flat_structure)
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
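For a scalar `batch_sizes` argument, the observable effect of `_RebatchDataset` can be approximated with public ops; a hedged sketch (the private op additionally supports cycling through a vector of batch sizes):

```python
import tensorflow as tf

# Re-batch batches of 4 into batches of 2 using only public transformations.
global_batches = tf.data.Dataset.range(8).batch(4)  # [0 1 2 3], [4 5 6 7]
rebatched = global_batches.unbatch().batch(2)       # [0 1], [2 3], [4 5], [6 7]
```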
Example #7
    def _may_form_partial_batches(self, desired_batch_size):
        """Returns whether this dataset may form partial batches."""
        if tensor_util.constant_value(self._drop_remainder):
            return False

        def get_batch_dim(type_spec):
            shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
            if not isinstance(shape, tensor_shape.TensorShape):
                return None
            if shape.rank is None:
                return None
            if len(shape) < 1:
                raise ValueError(
                    "Expected a dataset whose elements have rank >= 1 "
                    "but found a dataset whose elements are scalars. "
                    "You can fix the issue by adding the `batch` "
                    "transformation to the dataset.")
            return shape.dims[0].value

        input_batch_dims = [
            get_batch_dim(ts) for ts in nest.flatten(
                dataset_ops.get_structure(self._input_dataset))
        ]
        known_input_batch_dims = [d for d in input_batch_dims if d is not None]

        if not known_input_batch_dims:
            return True

        known_input_batch_dims = np.asarray(known_input_batch_dims)
        if not np.all(known_input_batch_dims == known_input_batch_dims[0]):
            raise ValueError(
                "Batch dimensions of input dataset are not compatible.")

        return known_input_batch_dims[0] % desired_batch_size != 0
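A small worked example of the final modulo check, with assumed batch dimensions:

```python
# If every component's known batch dimension is 6 and the desired batch size
# is 4, then 6 % 4 != 0, so the dataset may form partial batches.
known_input_batch_dims = [6, 6]
desired_batch_size = 4
may_form_partial = known_input_batch_dims[0] % desired_batch_size != 0  # True
```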
Example #8
    def write(self, dataset):
        """Writes a dataset to a TFRecord file.

    An operation that writes the content of the specified dataset to the file
    specified in the constructor.

    If the file exists, it will be overwritten.

    Args:
      dataset: a `tf.data.Dataset` whose elements are to be written to a file

    Returns:
      In graph mode, this returns an operation which when executed performs the
      write. In eager mode, the write is performed by the method itself and
      there is no return value.

    Raises:
      TypeError: if `dataset` is not a `tf.data.Dataset`.
      TypeError: if the elements produced by the dataset are not scalar strings.
    """
        if not isinstance(dataset, dataset_ops.DatasetV2):
            raise TypeError(
                f"Invalid `dataset.` Expected a `tf.data.Dataset` object but got "
                f"{type(dataset)}.")
        if not dataset_ops.get_structure(dataset).is_compatible_with(
                tensor_spec.TensorSpec([], dtypes.string)):
            raise TypeError(
                f"Invalid `dataset`. Expected a`dataset` that produces scalar "
                f"`tf.string` elements, but got a dataset which produces elements "
                f"with shapes {dataset_ops.get_legacy_output_shapes(dataset)} and "
                f"types {dataset_ops.get_legacy_output_types(dataset)}.")
        # pylint: disable=protected-access
        dataset = dataset._apply_debug_options()
        return gen_experimental_dataset_ops.dataset_to_tf_record(
            dataset._variant_tensor, self._filename, self._compression_type)
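A hedged usage sketch of this method via its public wrapper, `tf.data.experimental.TFRecordWriter` (the output path is illustrative):

```python
import tensorflow as tf

# Write a dataset of scalar strings to a TFRecord file. In eager mode the
# write happens immediately; in graph mode `write` returns an op to run.
dataset = tf.data.Dataset.from_tensor_slices([b"record-1", b"record-2"])
writer = tf.data.experimental.TFRecordWriter("/tmp/example.tfrecord")
writer.write(dataset)
```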
Example #9
 def testRoundtripMap(self):
     dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x)
     variant = dataset_ops.to_variant(dataset)
     dataset = dataset_ops.from_variant(variant,
                                        dataset_ops.get_structure(dataset))
     self.assertDatasetProduces(dataset, [x * x for x in range(10)])
     self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
Example #10
 def testRoundtripRange(self):
     dataset = dataset_ops.Dataset.range(10)
     variant = dataset_ops.to_variant(dataset)
     dataset = dataset_ops.from_variant(variant,
                                        dataset_ops.get_structure(dataset))
     self.assertDatasetProduces(dataset, range(10))
     self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
Example #11
    def __init__(self, input_dataset, num_replicas, use_fallback=True):
        def recalculate_batch_size(output_shapes):
            """Recalculates the output_shapes after dividing it by num_replicas."""
            if len(output_shapes) < 1:
                raise ValueError(
                    "Input shape should have at least one dimension. "
                    "Perhaps your input dataset is not batched?")
            output_dims = [d.value for d in output_shapes.dims]

            if (output_dims[0] is not None
                    and output_dims[0] % num_replicas == 0):
                return output_dims[0] // num_replicas

            # Set the batch dimension to unknown. If the global batch size does not
            # divide num_replicas evenly, the minibatches may have different sizes.
            return None

        def rebatch(type_spec):
            # pylint: disable=protected-access
            batch_size = recalculate_batch_size(
                type_spec._to_legacy_output_shapes())
            return type_spec._unbatch()._batch(batch_size)
            # pylint: enable=protected-access

        self._element_spec = nest.map_structure(
            rebatch, dataset_ops.get_structure(input_dataset))
        variant_tensor = ged_ops.rebatch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_replicas=num_replicas,
            **self._flat_structure)
        super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
Example #12
 def testRoundtripMap(self):
   dataset = dataset_ops.Dataset.range(10).map(lambda x: x*x)
   variant = dataset_ops.to_variant(dataset)
   dataset = dataset_ops.from_variant(variant,
                                      dataset_ops.get_structure(dataset))
   self.assertDatasetProduces(dataset, [x * x for x in range(10)])
   self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
Example #13
 def testRoundtripRange(self):
   dataset = dataset_ops.Dataset.range(10)
   variant = dataset_ops.to_variant(dataset)
   dataset = dataset_ops.from_variant(variant,
                                      dataset_ops.get_structure(dataset))
   self.assertDatasetProduces(dataset, range(10))
   self.assertEqual(self.evaluate(cardinality.cardinality(dataset)), 10)
Example #14
    def write(self, dataset):
        """Returns a `tf.Operation` to write a dataset to a file.

    Args:
      dataset: a `tf.data.Dataset` whose elements are to be written to a file

    Returns:
      A `tf.Operation` that, when run, writes contents of `dataset` to a file.
    """
        if not isinstance(dataset, dataset_ops.DatasetV2):
            raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
        if not dataset_ops.get_structure(dataset).is_compatible_with(
                structure.TensorStructure(dtypes.string, [])):
            raise TypeError(
                "`dataset` must produce scalar `DT_STRING` tensors whereas it "
                "produces shape {0} and types {1}".format(
                    dataset_ops.get_legacy_output_shapes(dataset),
                    dataset_ops.get_legacy_output_types(dataset)))
        if compat.forward_compatible(2019, 8, 3):
            return gen_experimental_dataset_ops.dataset_to_tf_record(
                dataset._variant_tensor, self._filename,
                self._compression_type)  # pylint: disable=protected-access
        else:
            return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
                dataset._variant_tensor, self._filename,
                self._compression_type)  # pylint: disable=protected-access
Example #15
    def _testDatasetSpec(self, tf_value, expected_element_structure):
        dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value)
        dataset_structure = structure.type_spec_from_value(dataset)
        self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(dataset),
                                     expected_element_structure))
        self.assertEqual([dtypes.variant],
                         structure.get_flat_tensor_types(dataset_structure))
        self.assertEqual([tensor_shape.TensorShape([])],
                         structure.get_flat_tensor_shapes(dataset_structure))

        # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_dataset = dataset_structure._from_tensor_list(
            dataset_structure._to_tensor_list(dataset))

        value = tf_value

        if isinstance(value, dataset_ops.Dataset):
            self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
        elif isinstance(value, optional_ops.Optional):
            self.assertDatasetProduces(
                round_trip_dataset.map(lambda opt: opt.get_value()),
                [self.evaluate(value.get_value())],
                requires_initialization=True)
        else:
            self.assertDatasetProduces(round_trip_dataset,
                                       [self.evaluate(tf_value)],
                                       requires_initialization=True)
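The round trip above uses the private `_to_tensor_list`/`_from_tensor_list` helpers; a minimal sketch of the public way to obtain a `DatasetSpec` in recent TF 2.x releases (the element value is illustrative):

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensors(0).map(lambda _: tf.constant(37.0))
spec = tf.type_spec_from_value(dataset)  # a tf.data.DatasetSpec
assert isinstance(spec, tf.data.DatasetSpec)
assert spec.element_spec == tf.TensorSpec(shape=[], dtype=tf.float32)
```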
Example #16
  def _may_form_partial_batches(self, desired_batch_size):
    """Returns whether this dataset may form partial batches."""
    if tensor_util.constant_value(self._drop_remainder):
      return False

    def get_batch_dim(type_spec):
      shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
      if not isinstance(shape, tensor_shape.TensorShape):
        return None
      if shape.rank is None:
        return None
      if len(shape) < 1:
        raise ValueError("Invalid `batch_sizes`. Expected dataset with "
                         "rank of >= 1 but found a dataset with "
                         "scalar elements. Fix the issue by adding the `batch` "
                         "transformation to the dataset.")
      return shape.dims[0].value

    input_batch_dims = [
        get_batch_dim(ts)
        for ts in nest.flatten(dataset_ops.get_structure(self._input_dataset))
    ]
    known_input_batch_dims = [d for d in input_batch_dims if d is not None]

    if not known_input_batch_dims:
      return True

    known_input_batch_dims = np.asarray(known_input_batch_dims)
    if not np.all(known_input_batch_dims == known_input_batch_dims[0]):
      raise ValueError(
          f"Invalid `input_dataset.` The batch dimension of component 0 "
          f"is {known_input_batch_dims[0]}, while the batch dimension "
          f"of component i is {known_input_batch_dims}.")

    return known_input_batch_dims[0] % desired_batch_size != 0
Example #17
def compute_batch_size(dataset):
  """An operation that returns the batch size of the dataset.

  This op tries to infer the batch size statically by walking up the dataset
  tree from the final dataset node and returning the batch size of the first
  batching dataset (such as from .batch() and .padded_batch()) that it
  encounters. This differs from using the `element_spec` of a dataset in that it
  does not account for partial batches.

  This operation may fail if it encounters contradictory batch sizes (for
  example, if the dataset is created by zipping together two datasets with
  different batch sizes), if there are no explicit batching transformations, or
  if there are operations downstream from the batching transformation that may
  modify its batch size. In these cases, it returns a -1.

  Args:
    dataset: A `tf.data.Dataset` object.

  Returns:
    A `tf.int64` Tensor representing the batch size of the dataset sans partial
    batches. If this cannot be inferred statically, the value of this tensor
    will be -1.
  """

  def get_static_batch_dim(type_spec):
    try:
      output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
    except NotImplementedError:
      return None
    if not isinstance(output_shape, tensor_shape.TensorShape):
      return None
    if output_shape.rank is None:
      return None
    return output_shape.dims[0].value

  batch_dims = [
      get_static_batch_dim(type_spec)
      for type_spec in nest.flatten(dataset_ops.get_structure(dataset))
  ]

  if all(d is not None for d in batch_dims):

    if all(d == batch_dims[0] for d in batch_dims):
      # If all batch dimensions are known and equal, return that directly.
      batch_dim = batch_dims[0]
    else:
      # If all batch dimensions are known but not all equal, return -1.
      batch_dim = -1

    return constant_op.constant(
        batch_dim, dtype=dtypes.int64, name="static_batch_size")

  # If any batch dimensions are unknown, use compute_batch_size op.
  return ged_ops.compute_batch_size(dataset._variant_tensor)  # pylint: disable=protected-access
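A hedged usage sketch of the public wrapper, `tf.data.experimental.compute_batch_size` (assuming a TF release that exports it):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(32).batch(8)
batch_size = tf.data.experimental.compute_batch_size(dataset)
# `batch_size` is a tf.int64 tensor holding 8; it would hold -1 if the batch
# size could not be inferred statically.
```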
Example #18
    def __init__(self, input_dataset, num_replicas):
        """Creates a _LegacyRebatchDataset.

    Args:
      input_dataset: `Dataset` to rebatch.
      num_replicas: A `tf.int64` scalar, representing the number of sub-batches
        to split each batch from `input_dataset` into.
    """
        def recalculate_batch_size(type_spec):
            """Recalculates the output_shape after dividing it by num_replicas."""
            output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
            if not isinstance(output_shape, tensor_shape.TensorShape):
                return None

            # If the output shape is unknown, we set the batch dimension to unknown.
            if output_shape.rank is None:
                return None

            if len(output_shape) < 1:
                raise ValueError(
                    "Invalid `input_dataset`. Expected a dataset whose elements "
                    "have rank >= 1 but found a dataset whose elements are scalars. "
                    "Fix the issue by adding the `batch` transformation to the "
                    "dataset.")
            output_dims = [d.value for d in output_shape.dims]

            if (output_dims[0] is not None
                    and output_dims[0] % num_replicas == 0):
                return output_dims[0] // num_replicas

            # Set the batch dimension to unknown. If the global batch size does not
            # divide num_replicas evenly, the minibatches may have different sizes.
            return None

        def rebatch(type_spec):
            # pylint: disable=protected-access
            batch_size = recalculate_batch_size(type_spec)
            return type_spec._unbatch()._batch(batch_size)
            # pylint: enable=protected-access

        self._element_spec = nest.map_structure(
            rebatch, dataset_ops.get_structure(input_dataset))

        # auto_shard rewrite assumes that there's normalize_to_dense before
        # rebatch_dataset.
        # LINT.IfChange
        input_dataset = dataset_ops.normalize_to_dense(input_dataset)
        variant_tensor = ged_ops.rebatch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_replicas=num_replicas,
            **self._flat_structure)
        # LINT.ThenChange(//tensorflow/core/grappler/optimizers/data/auto_shard.cc)
        super(_LegacyRebatchDataset, self).__init__(input_dataset,
                                                    variant_tensor)
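A small worked example of the batch-dimension recalculation above, with assumed values:

```python
# A global batch of 16 split across num_replicas = 4 gives per-replica
# batches of 4. If the global batch size were not evenly divisible, the
# dimension would become None because the minibatches may differ in size.
global_batch_size, num_replicas = 16, 4
per_replica = (global_batch_size // num_replicas
               if global_batch_size % num_replicas == 0 else None)
assert per_replica == 4
```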
  def testCopyToDeviceInt32(self):
    host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
    device_dataset = host_dataset.apply(
        prefetching_ops.copy_to_device("/cpu:1"))

    with ops.device("/cpu:1"):
      iterator = dataset_ops.make_one_shot_iterator(device_dataset)
      next_element = iterator.get_next()

    self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
        dataset_ops.get_structure(device_dataset)))
    self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
        dataset_ops.get_structure(iterator)))

    self.assertEqual(dtypes.int32, next_element.dtype)
    self.assertEqual((4,), next_element.shape)

    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=worker_config):
      self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
Example #20
    def testCopyToDeviceInt32(self):
        host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
        device_dataset = host_dataset.apply(
            prefetching_ops.copy_to_device("/cpu:1"))

        with ops.device("/cpu:1"):
            iterator = dataset_ops.make_one_shot_iterator(device_dataset)
            next_element = iterator.get_next()

        self.assertTrue(
            structure.are_compatible(
                dataset_ops.get_structure(host_dataset),
                dataset_ops.get_structure(device_dataset)))

        self.assertEqual(dtypes.int32, next_element.dtype)
        self.assertEqual((4, ), next_element.shape)

        worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=worker_config):
            self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(next_element)
  def testPrefetchDictToDevice(self):
    host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
    device_dataset = host_dataset.apply(
        prefetching_ops.prefetch_to_device("/cpu:1"))

    with ops.device("/cpu:1"):
      iterator = dataset_ops.make_one_shot_iterator(device_dataset)
      next_element = iterator.get_next()

    self.assertTrue(structure.are_compatible(
        dataset_ops.get_structure(host_dataset),
        dataset_ops.get_structure(device_dataset)))

    self.assertEqual(dtypes.int64, next_element["a"].dtype)
    self.assertEqual([], next_element["a"].shape)

    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=worker_config):
      for i in range(10):
        self.assertEqual({"a": i}, self.evaluate(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
  def testCopyDictToDeviceWithPrefetch(self):
    host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
    device_dataset = host_dataset.apply(
        prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)

    with ops.device("/cpu:1"):
      iterator = dataset_ops.make_one_shot_iterator(device_dataset)
      next_element = iterator.get_next()

    self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
        dataset_ops.get_structure(device_dataset)))
    self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
        dataset_ops.get_structure(iterator)))

    self.assertEqual(dtypes.int64, next_element["a"].dtype)
    self.assertEqual([], next_element["a"].shape)

    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=worker_config):
      for i in range(10):
        self.assertEqual({"a": i}, self.evaluate(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
  def testPrefetchToSameDevice(self):
    host_dataset = dataset_ops.Dataset.range(10)
    device_dataset = host_dataset.apply(
        prefetching_ops.prefetch_to_device(
            "/job:localhost/replica:0/task:0/device:CPU:0"))

    with ops.device("/cpu:1"):
      iterator = dataset_ops.make_one_shot_iterator(device_dataset)
      next_element = iterator.get_next()

    self.assertTrue(structure.are_compatible(
        dataset_ops.get_structure(host_dataset),
        dataset_ops.get_structure(device_dataset)))

    self.assertEqual(dtypes.int64, next_element.dtype)
    self.assertEqual([], next_element.shape)

    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=worker_config):
      for i in range(10):
        self.assertEqual(i, self.evaluate(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
  def testCopyToDevice(self):
    host_dataset = dataset_ops.Dataset.range(10)
    device_dataset = host_dataset.apply(
        prefetching_ops.copy_to_device("/cpu:1"))

    with ops.device("/cpu:1"):
      iterator = dataset_ops.make_one_shot_iterator(device_dataset)
      next_element = iterator.get_next()

    self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
        dataset_ops.get_structure(device_dataset)))
    self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
        dataset_ops.get_structure(iterator)))

    self.assertEqual(dtypes.int64, next_element.dtype)
    self.assertEqual([], next_element.shape)

    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=worker_config):
      for i in range(10):
        self.assertEqual(i, self.evaluate(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
  def testPrefetchToSameDevice(self):
    host_dataset = dataset_ops.Dataset.range(10)
    device_dataset = host_dataset.apply(
        prefetching_ops.prefetch_to_device(
            "/job:localhost/replica:0/task:0/device:CPU:0"))

    with ops.device("/cpu:1"):
      iterator = dataset_ops.make_one_shot_iterator(device_dataset)
      next_element = iterator.get_next()

    self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
        dataset_ops.get_structure(device_dataset)))
    self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
        dataset_ops.get_structure(iterator)))

    self.assertEqual(dtypes.int64, next_element.dtype)
    self.assertEqual([], next_element.shape)

    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=worker_config):
      for i in range(10):
        self.assertEqual(i, self.evaluate(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
    def testIteratorStructure(self, tf_value_fn, expected_element_structure,
                              expected_output_classes, expected_output_types,
                              expected_output_shapes):
        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))
Example #27
def dataset_extractor(dataset,
                      num_elements,
                      filename,
                      feed_name,
                      print_stats=True):
    """Allows the user to extract tensors from a `tf.data.Dataset`.

    Args:
      dataset: An instance of `tf.data.Dataset` to extract elements from.
      num_elements: The number of elements to extract from the dataset.
      filename: Where to save the extracted elements to.
      feed_name: Name of the infeed the dataset is associated with.
      print_stats: Whether to print progress messages to the
        console.

    Note:
      All the tuple elements will be saved in the same binary file.

    Returns:
      The operation that will save the elements of the infeed to file.

    Raises:
      TypeError: if `dataset` is not an instance of `tf.data.Dataset`.
      ValueError: if `num_elements` is less than 1.
    """
    if num_elements < 1:
        raise ValueError("Expected `num_elements` to be at least 1.")

    if not isinstance(dataset, dataset_ops.DatasetV2):
        raise TypeError("Expected `dataset` argument to be of type "
                        "`tf.data.Dataset`, but got %s "
                        "instead." % str(dataset))

    try:
        dataset_variant = dataset._variant_tensor  # pylint: disable=protected-access
    except TypeError:
        dataset_variant = dataset._as_variant_tensor()  # pylint: disable=protected-access

    struct = dataset_ops.get_structure(dataset)

    return gen_dataset_exporters.dataset_extractor(dataset_variant,
                                                   print_stats, num_elements,
                                                   filename, feed_name,
                                                   **dataset._flat_structure)  # pylint: disable=protected-access
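A hedged usage sketch of `dataset_extractor` as defined above (the file path and feed name are illustrative):

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100).batch(10, drop_remainder=True)
save_op = dataset_extractor(dataset,
                            num_elements=5,
                            filename="/tmp/infeed_dump.bin",
                            feed_name="training_infeed")
# In graph mode, run `save_op` in a session to write the elements to disk.
```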
Example #28
  def __init__(self, input_dataset, window_size, window_shift, window_stride):
    """See `sliding_window_batch` for details."""
    self._input_dataset = input_dataset
    self._window_size = ops.convert_to_tensor(
        window_size, dtype=dtypes.int64, name="window_stride")
    self._window_stride = ops.convert_to_tensor(
        window_stride, dtype=dtypes.int64, name="window_stride")
    self._window_shift = ops.convert_to_tensor(
        window_shift, dtype=dtypes.int64, name="window_shift")

    input_structure = dataset_ops.get_structure(input_dataset)
    self._structure = input_structure._batch(None)  # pylint: disable=protected-access
    variant_tensor = ged_ops.experimental_sliding_window_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        window_size=self._window_size,
        window_shift=self._window_shift,
        window_stride=self._window_stride,
        **dataset_ops.flat_structure(self))
    super(_SlideDataset, self).__init__(input_dataset, variant_tensor)
Example #29
  def write(self, dataset):
    """Returns a `tf.Operation` to write a dataset to a file.

    Args:
      dataset: a `tf.data.Dataset` whose elements are to be written to a file

    Returns:
      A `tf.Operation` that, when run, writes contents of `dataset` to a file.
    """
    if not isinstance(dataset, dataset_ops.DatasetV2):
      raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
    if not dataset_ops.get_structure(dataset).is_compatible_with(
        structure.TensorStructure(dtypes.string, [])):
      raise TypeError(
          "`dataset` must produce scalar `DT_STRING` tensors whereas it "
          "produces shape {0} and types {1}".format(
              dataset_ops.get_legacy_output_shapes(dataset),
              dataset_ops.get_legacy_output_types(dataset)))
    return gen_experimental_dataset_ops.experimental_dataset_to_tf_record(
        dataset._variant_tensor, self._filename, self._compression_type)  # pylint: disable=protected-access
Example #30
    def __init__(self, input_dataset, num_replicas, use_fallback=True):
        def recalculate_batch_size(type_spec):
            """Recalculates the output_shape after dividing it by num_replicas."""
            output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
            if not isinstance(output_shape, tensor_shape.TensorShape):
                return None

            # If the output shape is unknown, we set the batch dimension to unknown.
            if output_shape.rank is None:
                return None

            if len(output_shape) < 1:
                raise ValueError(
                    "Expected a dataset whose elements have rank >= 1 "
                    "but found a dataset whose elements are scalars. "
                    "You can fix the issue by adding the `batch` "
                    "transformation to the dataset.")
            output_dims = [d.value for d in output_shape.dims]

            if (output_dims[0] is not None
                    and output_dims[0] % num_replicas == 0):
                return output_dims[0] // num_replicas

            # Set the batch dimension to unknown. If the global batch size does not
            # divide num_replicas evenly, the minibatches may have different sizes.
            return None

        def rebatch(type_spec):
            # pylint: disable=protected-access
            batch_size = recalculate_batch_size(type_spec)
            return type_spec._unbatch()._batch(batch_size)
            # pylint: enable=protected-access

        self._element_spec = nest.map_structure(
            rebatch, dataset_ops.get_structure(input_dataset))
        input_dataset = dataset_ops.normalize_to_dense(input_dataset)
        variant_tensor = ged_ops.rebatch_dataset(
            input_dataset._variant_tensor,  # pylint: disable=protected-access
            num_replicas=num_replicas,
            **self._flat_structure)
        super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
Example #31
  def __init__(self, input_dataset):
    """See `unbatch()` for more details."""
    input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
    flat_shapes = nest.flatten(input_shapes)
    if any(s.ndims == 0 for s in flat_shapes):
      raise ValueError("Cannot unbatch an input with scalar components.")
    known_batch_dim = tensor_shape.Dimension(None)
    for s in flat_shapes:
      try:
        known_batch_dim = known_batch_dim.merge_with(s[0])
      except ValueError:
        raise ValueError("Cannot unbatch an input whose components have "
                         "different batch sizes.")
    self._input_dataset = input_dataset

    self._structure = dataset_ops.get_structure(input_dataset)._unbatch()  # pylint: disable=protected-access

    variant_tensor = ged_ops.experimental_unbatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        **dataset_ops.flat_structure(self))
    super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
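The behaviour implemented by `_UnbatchDataset` corresponds to the public `unbatch` transformation in current TF releases; a minimal sketch:

```python
import tensorflow as tf

batched = tf.data.Dataset.range(6).batch(3)  # elements: [0 1 2], [3 4 5]
unbatched = batched.unbatch()                # elements: 0, 1, 2, 3, 4, 5
```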
Example #32
def choose_from_datasets_v2(datasets, choice_dataset):
  """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A list of `tf.data.Dataset` objects with compatible structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
      `0` and `len(datasets) - 1`.

  Returns:
    A dataset that interleaves elements from `datasets` according to the values
    of `choice_dataset`.

  Raises:
    TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
      type.
  """
  if not dataset_ops.get_structure(choice_dataset).is_compatible_with(
      structure.TensorStructure(dtypes.int64, [])):
    raise TypeError("`choice_dataset` must be a dataset of scalar "
                    "`tf.int64` tensors.")
  return _DirectedInterleaveDataset(choice_dataset, datasets)
Example #33
def choose_from_datasets_v2(datasets, choice_dataset):
    """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A list of `tf.data.Dataset` objects with compatible structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
      `0` and `len(datasets) - 1`.

  Returns:
    A dataset that interleaves elements from `datasets` according to the values
    of `choice_dataset`.

  Raises:
    TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
      type.
  """
    if not dataset_ops.get_structure(choice_dataset).is_compatible_with(
            structure.TensorStructure(dtypes.int64, [])):
        raise TypeError("`choice_dataset` must be a dataset of scalar "
                        "`tf.int64` tensors.")
    return _DirectedInterleaveDataset(choice_dataset, datasets)
Example #34
    def __init__(self, input_dataset):
        """See `unbatch()` for more details."""
        input_shapes = dataset_ops.get_legacy_output_shapes(input_dataset)
        flat_shapes = nest.flatten(input_shapes)
        if any(s.ndims == 0 for s in flat_shapes):
            raise ValueError("Cannot unbatch an input with scalar components.")
        known_batch_dim = tensor_shape.Dimension(None)
        for s in flat_shapes:
            try:
                known_batch_dim = known_batch_dim.merge_with(s[0])
            except ValueError:
                raise ValueError(
                    "Cannot unbatch an input whose components have "
                    "different batch sizes.")
        self._input_dataset = input_dataset

        self._structure = dataset_ops.get_structure(input_dataset)._unbatch()  # pylint: disable=protected-access

        variant_tensor = ged_ops.experimental_unbatch_dataset(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            **dataset_ops.flat_structure(self))
        super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
Example #35
    def testSparseTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        def tf_value_fn():
            return sparse_tensor.SparseTensor(indices=[[0]],
                                              values=constant_op.constant(
                                                  [0], dtype=dtypes.int32),
                                              dense_shape=[1])

        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))
Example #36
    def testNestedTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        def tf_value_fn():
            return {
                "a": constant_op.constant(37.0),
                "b":
                (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
            }

        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))
Example #37
    def testIteratorStringHandleFuture(self):
        with forward_compat.forward_compatibility_horizon(2018, 8, 4):
            dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            dataset_4 = dataset_ops.Dataset.from_tensor_slices(
                [10, 20, 30, 40])

            iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
            iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

            handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
            feedable_iterator = iterator_ops.Iterator.from_string_handle(
                handle_placeholder,
                dataset_ops.get_legacy_output_types(dataset_3),
                dataset_ops.get_legacy_output_shapes(dataset_3))
            next_element = feedable_iterator.get_next()

            self.assertTrue(
                structure.are_compatible(
                    dataset_ops.get_structure(dataset_3),
                    dataset_ops.get_structure(feedable_iterator)))

            with self.cached_session() as sess:
                iterator_3_handle = sess.run(iterator_3.string_handle())
                iterator_4_handle = sess.run(iterator_4.string_handle())

                self.assertEqual(
                    10,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_4_handle}))
                self.assertEqual(
                    1,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_3_handle}))
                self.assertEqual(
                    20,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_4_handle}))
                self.assertEqual(
                    2,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_3_handle}))
                self.assertEqual(
                    30,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_4_handle}))
                self.assertEqual(
                    3,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_3_handle}))
                self.assertEqual(
                    40,
                    sess.run(next_element,
                             feed_dict={handle_placeholder:
                                        iterator_4_handle}))
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(next_element,
                             feed_dict={handle_placeholder: iterator_3_handle})
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(next_element,
                             feed_dict={handle_placeholder: iterator_4_handle})
 def _element_structure(self):
   return dataset_ops.get_structure(self._dataset)
Example #39
    def __init__(self,
                 dataset,
                 feed_name,
                 device_ordinal=0,
                 replication_factor=1,
                 data_to_prefetch=1,
                 prefetch_depth=None):
        """Creates an IPUInfeedQueue object.

    Args:
       dataset: a `tf.data.Dataset` object; all transformations (e.g. `shuffle`,
         `repeat`, `batch`) must be applied before it is passed to this
         function. The dataset can no longer be used after creating this queue.
       feed_name: the name of the infeed queue.  This must be unique between
         all IPUInfeedQueues and IPUOutfeedQueues.
       device_ordinal: ordinal of the IPU device on which this queue will be
         used. By default the queue will be used on "/device/IPU:0".
       replication_factor: the number of replicated graphs this infeed will be
         used in.
       data_to_prefetch: the amount of data to prefetch. Defaults to 1 (no
         prefetch). If set to a value other than 0 or 1, each sync with the
         CPU returns this many dataset values rather than 1. The value must
         not exceed the size of the dataset if it does not repeat, and the
         infeed advances by this amount on every sync, so when the infeed is
         used in multiple programs or loops, note that if `data_to_prefetch`
         does not evenly divide the number of previous iterations, the next
         loop/program will not start at the iteration it otherwise would.
         Prefetching more data increases memory usage, since more batches are
         live at once, but should give a speed-up by reducing the number of
         round trips to host memory. A larger number of prefetched batches
         may be needed to see any benefit, as the lookup itself has some
         overhead from internal copies.
       prefetch_depth: the number of elements Poplar will prefetch: the depth
         of the Poplar datastream buffer that may be filled before being read
         by the device. By default the prefetch depth is determined
         automatically. Increasing it allows multiple entries to be
         prefetched, increasing the probability that a valid entry is already
         in the buffer when the device reads, instead of falling back to
         synchronously fetching the next entry.

    Raises:
      ValueError: if any dimension of the shapes in `dataset.output_shapes` is
        not fully defined. The `tf.data.Dataset.batch` function must be called
        with `drop_remainder=True` to ensure that the batch size is constant.

    """

        for output_shape in dataset._flat_structure["output_shapes"]:
            if isinstance(output_shape, list) or isinstance(
                    output_shape, tuple):
                raise ValueError(
                    "Nested list/tuple input shapes are not supported")
            if not output_shape.is_fully_defined():
                raise ValueError(
                    """Output shape {} is not fully defined. If using \
tf.data.Dataset.batch, set `drop_remainder=True`.""".format(output_shape))
        if prefetch_depth is None:
            prefetch_depth = 1
        if prefetch_depth <= 0:
            raise ValueError(
                "prefetch_depth must be greater than zero, but it is {}".
                format(prefetch_depth))
        if prefetch_depth > 255:
            raise ValueError(
                "prefetch_depth must be less than 256, but it is {}".format(
                    prefetch_depth))

        with ops.device('/device:CPU:0'):
            self._replication_factor = replication_factor
            self._dataset = dataset
            self._structure = dataset_ops.get_structure(self._dataset)
            self._flat_structure = dataset._flat_structure
            self._device_ordinal = device_ordinal
            self._prefetch_depth = prefetch_depth

            # We use max to clamp 0/1 to the same value.
            self._io_batch_size = max(1, data_to_prefetch)

            # Batch the dataset to take replication and prefetch into account.

            if self._io_batch_size != 1:
                self._dataset = self._dataset.batch(self._io_batch_size,
                                                    drop_remainder=True)

            if self._replication_factor != 1:
                self._dataset = self._dataset.batch(self._replication_factor,
                                                    drop_remainder=True)

            # Apply the dataset and take ownership.
            self._dataset = self._dataset._apply_options()

            # ID used for differentiating between datasets.
            self._id = str(feed_name)

            try:
                ds_variant = self._dataset._variant_tensor  # pylint: disable=protected-access
            except TypeError:
                ds_variant = self._dataset._as_variant_tensor()  # pylint: disable=protected-access

            if not context.executing_eagerly():
                # For Estimators, the graph can be frozen before the estimator calls
                # the initializer or deleter methods.  So we need to create the
                # initialize and delete operations early.  For eager execution in
                # TF2, the operations execute eagerly, so they don't exist in any
                # frozen graph.
                with ops.colocate_with(ds_variant):
                    self._init_op = gen_pop_datastream_ops.ipu_create_dataset_iterator(
                        input_dataset=ds_variant,
                        feed_id=self._id,
                        replication_factor=self._replication_factor,
                        device_ordinal=self._device_ordinal,
                        **self._dataset._flat_structure)  # pylint: disable=protected-access

                self._deleter = gen_pop_datastream_ops.ipu_delete_dataset_iterator(
                    feed_id=self._id, device_ordinal=self._device_ordinal)

        self._dequeued = False
        self._initialized = False
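A hedged usage sketch of constructing the queue; the import path and the `initializer` property follow the Graphcore IPU API referenced in the comments above and should be treated as assumptions:

```python
import tensorflow as tf
from tensorflow.python.ipu import ipu_infeed_queue  # import path is an assumption

dataset = tf.data.Dataset.range(1000).batch(16, drop_remainder=True).repeat()
infeed = ipu_infeed_queue.IPUInfeedQueue(dataset, feed_name="training_infeed")
# The queue must be initialised (e.g. sess.run(infeed.initializer)) before a
# compiled loop starts dequeuing from it.
```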
Example #40
    def __init__(self,
                 input_dataset,
                 functions,
                 ratio_numerator=1,
                 ratio_denominator=1,
                 num_elements_per_branch=None):
        """Chooses the fastest of some dataset functions.

    Given dataset functions that take input_dataset as input and output
    another dataset, produces elements as quickly as the fastest of these
    output datasets. Note that datasets in the dataset functions are assumed
    to be stateless, and the iterators created by the functions' output datasets
    will, given the same input elements, all produce the same output elements.
    Datasets in the functions are also expected to iterate over the input
    dataset at most once. The violation of these conditions may lead to
    undefined behavior.

    For example:
    ```python
    dataset = tf.data.Dataset.range(100)
    dataset = _ChooseFastestBranchDataset(
        dataset,
        [
            lambda ds: ds.map(lambda x: tf.reshape(x, [1])).batch(10),
            lambda ds: ds.batch(10).map(lambda x: tf.reshape(x, [10, 1]))
        ],
        ratio_numerator=10,
        num_elements_per_branch=10
    )
    ```
    The resulting dataset will produce elements equivalent to
    `tf.data.Dataset.range(100).map(lambda x: tf.reshape(x, [1])).batch(10)`, or
    `tf.data.Dataset.range(100).batch(10).map(lambda x: tf.reshape(x, [10, 1]))`

    Note that the first `num_elements_per_branch` iterations may be slower due
    to the overhead of dynamically picking the fastest dataset. Namely, for
    these iterations, the dataset will produce elements from any of the
    branches to determine which input is the fastest. For all subsequent
    iterations, that input will be used.

    Args:
      input_dataset: A `Dataset` that can be used as input to `functions`.
      functions: A list of callables, each of which takes a `Dataset` as input
        and returns a `Dataset`.
      ratio_numerator: The numerator in the ratio of input elements consumed to
        output elements produced for each function. This should be the same for
        all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
          must produce 10 elements for every element of the output dataset. In
          this case, ratio_numerator should be 10.
      ratio_denominator: The denominator in the ratio of input elements consumed
        to output elements produced for each function. This should be the same
        for all functions. For example, if the function is
        `lambda ds: ds.batch(10)`, the ratio is 10:1, i.e. the input dataset
          must produce 10 elements for every element of the output dataset. In
          this case, ratio_denominator should be 1.
      num_elements_per_branch: The number of elements to get from each branch
        before deciding which dataset is fastest. In the first
        `len(functions) * num_elements_per_branch` iterations, the dataset
        will draw elements from one of the branches and update its knowledge
        of which input is the fastest. Note that
        `num_elements_per_branch * ratio` is expected to be an integer.

    Returns:
      A `Dataset` that has the same elements as the inputs.
    """
        input_structure = dataset_ops.DatasetStructure(
            dataset_ops.get_structure(input_dataset))
        self._funcs = [
            dataset_ops.StructuredFunctionWrapper(
                f, "ChooseFastestV2", input_structure=input_structure)
            for f in functions
        ]
        self._structure = self._funcs[0].output_structure._element_structure  # pylint: disable=protected-access

        self._captured_arguments = []
        for f in self._funcs:
            self._captured_arguments.extend(f.function.captured_inputs)
        self._capture_lengths = [
            len(f.function.captured_inputs) for f in self._funcs
        ]

        if ratio_numerator <= 0 or ratio_denominator <= 0:
            raise ValueError("ratio must be positive.")

        if num_elements_per_branch is None:
            # Pick a sensible default based on `ratio_denominator`
            num_elements_per_branch = 10 * ratio_denominator

        variant_tensor = (
            gen_experimental_dataset_ops.choose_fastest_branch_dataset(
                input_dataset._variant_tensor,  # pylint: disable=protected-access
                ratio_numerator=ratio_numerator,
                ratio_denominator=ratio_denominator,
                other_arguments=self._captured_arguments,
                num_elements_per_branch=num_elements_per_branch,
                branches=[f.function for f in self._funcs],
                other_arguments_lengths=self._capture_lengths,
                **self._flat_structure))
        super(_ChooseFastestBranchDataset,
              self).__init__(input_dataset, variant_tensor)
Example #41
 def _element_structure(self):
     return dataset_ops.get_structure(self._dataset)