Example #1
    def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                      gpu_compatible):
        if not gpu_compatible and test.is_gpu_available():
            self.skipTest("Test case not yet supported on GPU.")
        ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)

        if context.executing_eagerly():
            iterator = dataset_ops.make_one_shot_iterator(ds)
            # For each element of the dataset, assert that the optional evaluates to
            # the expected value.
            for _ in range(3):
                next_elem = iterator_ops.get_next_as_optional(iterator)
                self.assertIsInstance(next_elem, optional_ops.Optional)
                self.assertTrue(
                    structure.are_compatible(
                        next_elem.element_spec,
                        structure.type_spec_from_value(tf_value_fn())))
                self.assertTrue(next_elem.has_value())
                self.assertValuesEqual(np_value, next_elem.get_value())
            # After exhausting the iterator, `next_elem.has_value()` will evaluate to
            # false, and attempting to get the value will fail.
            for _ in range(2):
                next_elem = iterator_ops.get_next_as_optional(iterator)
                self.assertFalse(self.evaluate(next_elem.has_value()))
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(next_elem.get_value())
        else:
            iterator = dataset_ops.make_initializable_iterator(ds)
            next_elem = iterator_ops.get_next_as_optional(iterator)
            self.assertIsInstance(next_elem, optional_ops.Optional)
            self.assertTrue(
                structure.are_compatible(
                    next_elem.element_spec,
                    structure.type_spec_from_value(tf_value_fn())))
            # Before initializing the iterator, evaluating the optional fails with
            # a FailedPreconditionError. This is only relevant in graph mode.
            elem_has_value_t = next_elem.has_value()
            elem_value_t = next_elem.get_value()
            with self.assertRaises(errors.FailedPreconditionError):
                self.evaluate(elem_has_value_t)
            with self.assertRaises(errors.FailedPreconditionError):
                self.evaluate(elem_value_t)
            # Now we initialize the iterator.
            self.evaluate(iterator.initializer)
            # For each element of the dataset, assert that the optional evaluates to
            # the expected value.
            for _ in range(3):
                elem_has_value, elem_value = self.evaluate(
                    [elem_has_value_t, elem_value_t])
                self.assertTrue(elem_has_value)
                self.assertValuesEqual(np_value, elem_value)

            # After exhausting the iterator, `next_elem.has_value()` will evaluate to
            # false, and attempting to get the value will fail.
            for _ in range(2):
                self.assertFalse(self.evaluate(elem_has_value_t))
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(elem_value_t)
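
The same drain-the-iterator pattern is available through the public API; a minimal eager-mode sketch, assuming TF 2.x where `iter(dataset)` yields a `tf.data.Iterator` exposing `get_next_as_optional()`:

```python
import tensorflow as tf

# Minimal sketch (assumes TF 2.x eager execution): consume a dataset via
# optionals instead of catching OutOfRangeError.
ds = tf.data.Dataset.from_tensors(37).repeat(3)
it = iter(ds)
opt = it.get_next_as_optional()
while opt.has_value():              # scalar bool tensor; truthy in eager mode
    print(opt.get_value().numpy())  # prints 37 three times
    opt = it.get_next_as_optional()
```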
Example #2
 def testIsCompatibleWithStructure(self, original_value_fn,
                                   compatible_values_fn,
                                   incompatible_values_fn):
     original_value = original_value_fn()
     compatible_values = compatible_values_fn()
     incompatible_values = incompatible_values_fn()
     s = structure.type_spec_from_value(original_value)
     for compatible_value in compatible_values:
         self.assertTrue(
             structure.are_compatible(
                 s, structure.type_spec_from_value(compatible_value)))
     for incompatible_value in incompatible_values:
         self.assertFalse(
             structure.are_compatible(
                 s, structure.type_spec_from_value(incompatible_value)))
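
For single tensor specs, the public counterpart of this internal check is `tf.TensorSpec.is_compatible_with`; a short sketch of the compatibility relation the test exercises:

```python
import tensorflow as tf

spec = tf.TensorSpec(shape=[None, 3], dtype=tf.float32)
print(spec.is_compatible_with(tf.TensorSpec([5, 3], tf.float32)))  # True: None matches 5
print(spec.is_compatible_with(tf.TensorSpec([5, 4], tf.float32)))  # False: 3 != 4
print(spec.is_compatible_with(tf.TensorSpec([5, 3], tf.int32)))    # False: dtype mismatch
```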
Example #3
    def testCopySparseTensorsToDeviceWithPrefetch(self):
        def make_tensor(i):
            return sparse_tensor.SparseTensorValue(indices=[[0, 0]],
                                                   values=(i * [1]),
                                                   dense_shape=[2, 2])

        host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)

        device_dataset = host_dataset.apply(
            prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)

        with ops.device("/cpu:1"):
            iterator = dataset_ops.make_one_shot_iterator(device_dataset)
            next_element = iterator.get_next()

        self.assertTrue(
            structure.are_compatible(
                dataset_ops.get_structure(host_dataset),
                dataset_ops.get_structure(device_dataset)))

        self.assertEqual(dtypes.int64, next_element.dtype)

        worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=worker_config):
            for i in range(10):
                actual = self.evaluate(next_element)
                self.assertAllEqual([i], actual.values)
                self.assertAllEqual([[0, 0]], actual.indices)
                self.assertAllEqual([2, 2], actual.dense_shape)
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(next_element)
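
The transformation under test is exposed publicly as `tf.data.experimental.copy_to_device`; a hedged sketch (the device string is an assumption, substitute one available on your machine):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)
# Stage elements onto the target device and overlap the copy with
# consumption via a one-element prefetch buffer.
ds = ds.apply(tf.data.experimental.copy_to_device("/gpu:0")).prefetch(1)
with tf.device("/gpu:0"):  # the iterator should live on the target device
    for x in ds:
        print(x.numpy())
```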
Example #4
    def testCopyToDeviceWithReInitAndPrefetch(self):
        host_dataset = dataset_ops.Dataset.range(10)
        device_dataset = host_dataset.apply(
            prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)

        with ops.device("/cpu:1"):
            iterator = dataset_ops.make_initializable_iterator(device_dataset)
            next_element = iterator.get_next()

        self.assertTrue(
            structure.are_compatible(
                dataset_ops.get_structure(host_dataset),
                dataset_ops.get_structure(device_dataset)))

        self.assertEqual(dtypes.int64, next_element.dtype)
        self.assertEqual([], next_element.shape)

        worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=worker_config):
            self.evaluate(iterator.initializer)
            for i in range(5):
                self.assertEqual(i, self.evaluate(next_element))
            self.evaluate(iterator.initializer)
            for i in range(10):
                self.assertEqual(i, self.evaluate(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(next_element)
Example #5
    def checkDatasetSpec(self, tf_value, expected_element_structure):
        dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value)
        dataset_structure = structure.type_spec_from_value(dataset)
        self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(dataset),
                                     expected_element_structure))
        self.assertEqual([dtypes.variant],
                         structure.get_flat_tensor_types(dataset_structure))
        self.assertEqual([tensor_shape.TensorShape([])],
                         structure.get_flat_tensor_shapes(dataset_structure))

        # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_dataset = dataset_structure._from_tensor_list(
            dataset_structure._to_tensor_list(dataset))

        value = tf_value

        if isinstance(value, dataset_ops.Dataset):
            self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
        elif isinstance(value, optional_ops.Optional):
            self.assertDatasetProduces(
                round_trip_dataset.map(lambda opt: opt.get_value()),
                [self.evaluate(value.get_value())],
                requires_initialization=True)
        else:
            self.assertDatasetProduces(round_trip_dataset,
                                       [self.evaluate(tf_value)],
                                       requires_initialization=True)
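
Publicly, datasets are first-class values with their own `TypeSpec`; a small sketch, assuming TF 2.4+ where `tf.type_spec_from_value` is exported:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(5)
spec = tf.type_spec_from_value(ds)
print(isinstance(spec, tf.data.DatasetSpec))  # True
print(spec.element_spec)  # TensorSpec(shape=(), dtype=tf.int64, name=None)
```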
Example #6
    def assertDatasetsEqual(self, dataset1, dataset2):
        """Checks that datasets are equal. Supports both graph and eager mode."""
        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(dataset1),
                                     dataset_ops.get_structure(dataset2)))

        flattened_types = nest.flatten(
            dataset_ops.get_legacy_output_types(dataset1))

        next1 = self.getNext(dataset1)
        next2 = self.getNext(dataset2)

        while True:
            try:
                op1 = self.evaluate(next1())
            except errors.OutOfRangeError:
                with self.assertRaises(errors.OutOfRangeError):
                    self.evaluate(next2())
                break
            op2 = self.evaluate(next2())

            op1 = nest.flatten(op1)
            op2 = nest.flatten(op2)
            assert len(op1) == len(op2)
            for i in range(len(op1)):
                if sparse_tensor.is_sparse(op1[i]) or ragged_tensor.is_ragged(
                        op1[i]):
                    self.assertValuesEqual(op1[i], op2[i])
                elif flattened_types[i] == dtypes.string:
                    self.assertAllEqual(op1[i], op2[i])
                else:
                    self.assertAllClose(op1[i], op2[i])
Example #7
    def testIteratorStringHandle(self):
        dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
        dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

        iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
        iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

        handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
        feedable_iterator = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
            dataset_ops.get_legacy_output_shapes(dataset_3))
        next_element = feedable_iterator.get_next()

        self.assertTrue(
            structure.are_compatible(
                dataset_ops.get_structure(dataset_3),
                dataset_ops.get_structure(feedable_iterator)))

        with self.cached_session() as sess:
            iterator_3_handle = sess.run(iterator_3.string_handle())
            iterator_4_handle = sess.run(iterator_4.string_handle())

            self.assertEqual(
                10,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                1,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                20,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                2,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                30,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            self.assertEqual(
                3,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle}))
            self.assertEqual(
                40,
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle}))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_3_handle})
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element,
                         feed_dict={handle_placeholder: iterator_4_handle})
Example #8
def choose_from_datasets_v2(datasets,
                            choice_dataset,
                            stop_on_empty_dataset=False):
    """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A non-empty list of `tf.data.Dataset` objects with compatible
      structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between `0`
      and `len(datasets) - 1`.
    stop_on_empty_dataset: If `True`, selection stops if it encounters an empty
      dataset. If `False`, it skips empty datasets. It is recommended to set it
      to `True`. Otherwise, the selected elements start off as the user intends,
      but may change as input datasets become empty. This can be difficult to
      detect since the dataset starts off looking correct. Defaults to `False`
      for backward compatibility.

  Returns:
    A dataset that interleaves elements from `datasets` according to the values
    of `choice_dataset`.

  Raises:
    TypeError: If `datasets` or `choice_dataset` has the wrong type.
    ValueError: If `datasets` is empty.
  """
    if not datasets:
        raise ValueError("`datasets` must be a non-empty list of datasets.")
    if choice_dataset is None or not structure.are_compatible(
            choice_dataset.element_spec,
            tensor_spec.TensorSpec([], dtypes.int64)):
        raise TypeError("`choice_dataset` must be a dataset of scalar "
                        "`tf.int64` tensors.")
    # pylint: disable=protected-access
    return dataset_ops._DirectedInterleaveDataset(choice_dataset, datasets,
                                                  stop_on_empty_dataset)
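
A quick sketch of the `stop_on_empty_dataset` behavior described in the docstring, assuming a TF release whose `tf.data.experimental.choose_from_datasets` accepts the flag:

```python
import tensorflow as tf

datasets = [tf.data.Dataset.from_tensors("foo").repeat(2),
            tf.data.Dataset.from_tensors("bar").repeat()]
choice = tf.data.Dataset.range(2).repeat(4)  # 0, 1, 0, 1, 0, 1, 0, 1
result = tf.data.experimental.choose_from_datasets(
    datasets, choice, stop_on_empty_dataset=True)
# Selection halts when the exhausted "foo" dataset is chosen a third time.
print([e.numpy() for e in result])  # [b'foo', b'bar', b'foo', b'bar']
```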
Example #9
    def testOptionalStructure(self, tf_value_fn, expected_value_structure):
        tf_value = tf_value_fn()
        opt = optional_ops.Optional.from_value(tf_value)

        self.assertTrue(
            structure.are_compatible(opt.value_structure,
                                     expected_value_structure))

        opt_structure = structure.type_spec_from_value(opt)
        self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
        self.assertTrue(structure.are_compatible(opt_structure, opt_structure))
        self.assertTrue(
            structure.are_compatible(opt_structure._value_structure,
                                     expected_value_structure))
        self.assertEqual([dtypes.variant],
                         structure.get_flat_tensor_types(opt_structure))
        self.assertEqual([tensor_shape.scalar()],
                         structure.get_flat_tensor_shapes(opt_structure))

        # An `OptionalStructure` is never compatible with the structure of a
        # non-optional value.
        non_optional_structure = structure.type_spec_from_value(
            constant_op.constant(42.0))
        self.assertFalse(
            opt_structure.is_compatible_with(non_optional_structure))

        # Assert that the optional survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_opt = opt_structure._from_tensor_list(
            opt_structure._to_tensor_list(opt))
        if isinstance(tf_value, optional_ops.Optional):
            self._assertElementValueEqual(
                self.evaluate(tf_value.get_value()),
                self.evaluate(round_trip_opt.get_value().get_value()))
        else:
            self._assertElementValueEqual(
                self.evaluate(tf_value),
                self.evaluate(round_trip_opt.get_value()))
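
The `Optional` machinery exercised here is public as well; a minimal sketch, assuming TF 2.3+ where it lives at `tf.experimental.Optional` (older releases expose it as `tf.data.experimental.Optional`):

```python
import tensorflow as tf

opt = tf.experimental.Optional.from_value(tf.constant(42.0))
print(opt.has_value().numpy())    # True
print(opt.get_value().numpy())    # 42.0
# An empty optional with the same element spec reports no value.
empty = tf.experimental.Optional.empty(opt.element_spec)
print(empty.has_value().numpy())  # False
```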
Example #10
    def testIteratorStructure(self, tf_value_fn, expected_element_structure,
                              expected_output_classes, expected_output_types,
                              expected_output_shapes):
        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))
Example #11
  def _assert_datasets_equal(self, ds1, ds2):
    # First, let's assert that the structures are the same.
    self.assertTrue(
        structure.are_compatible(ds1.element_spec, ds2.element_spec))

    # Now create iterators on both and assert they produce the same values.
    it1 = dataset_ops.make_initializable_iterator(ds1)
    it2 = dataset_ops.make_initializable_iterator(ds2)

    get_next1 = it1.get_next()
    get_next2 = it2.get_next()

    with self.cached_session():
      self.evaluate([it1.initializer, it2.initializer])
      val1, val2 = self.evaluate([get_next1, get_next2])
      self.assertEqual(val1, val2)
Example #12
def choose_from_datasets_v2(datasets, choice_dataset):
    """Creates a dataset that deterministically chooses elements from `datasets`.

  For example, given the following datasets:

  ```python
  datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
              tf.data.Dataset.from_tensors("bar").repeat(),
              tf.data.Dataset.from_tensors("baz").repeat()]

  # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
  choice_dataset = tf.data.Dataset.range(3).repeat(3)

  result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
  ```

  The elements of `result` will be:

  ```
  "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
  ```

  Args:
    datasets: A list of `tf.data.Dataset` objects with compatible structure.
    choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
      `0` and `len(datasets) - 1`.

  Returns:
    A dataset that interleaves elements from `datasets` according to the values
    of `choice_dataset`.

  Raises:
    TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
      type.
  """
    if not structure.are_compatible(
            choice_dataset.element_spec,
            structure.TensorStructure(dtypes.int64, [])):
        raise TypeError("`choice_dataset` must be a dataset of scalar "
                        "`tf.int64` tensors.")
    return _DirectedInterleaveDataset(choice_dataset, datasets)
Example #13
    def testCopyToDeviceInt32(self):
        host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
        device_dataset = host_dataset.apply(
            prefetching_ops.copy_to_device("/cpu:1"))

        with ops.device("/cpu:1"):
            iterator = dataset_ops.make_one_shot_iterator(device_dataset)
            next_element = iterator.get_next()

        self.assertTrue(
            structure.are_compatible(
                dataset_ops.get_structure(host_dataset),
                dataset_ops.get_structure(device_dataset)))

        self.assertEqual(dtypes.int32, next_element.dtype)
        self.assertEqual((4,), next_element.shape)

        worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
        with self.test_session(config=worker_config):
            self.assertAllEqual([0, 1, 2, 3], self.evaluate(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                self.evaluate(next_element)
Example #14
  def testPrefetchDictToDevice(self):
    host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
    device_dataset = host_dataset.apply(
        prefetching_ops.prefetch_to_device("/cpu:1"))

    with ops.device("/cpu:1"):
      iterator = dataset_ops.make_one_shot_iterator(device_dataset)
      next_element = iterator.get_next()

    self.assertTrue(structure.are_compatible(
        dataset_ops.get_structure(host_dataset),
        dataset_ops.get_structure(device_dataset)))

    self.assertEqual(dtypes.int64, next_element["a"].dtype)
    self.assertEqual([], next_element["a"].shape)

    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=worker_config):
      for i in range(10):
        self.assertEqual({"a": i}, self.evaluate(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
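
The public form of this transformation is `tf.data.experimental.prefetch_to_device`; a hedged sketch (the device string is an assumption), noting that per its docs it must be the final transformation in the pipeline:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10).map(lambda x: {"a": x})
# Like prefetch(), but the buffer is allocated on the target device.
ds = ds.apply(tf.data.experimental.prefetch_to_device("/gpu:0"))
```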
Example #15
    def testSparseTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        def tf_value_fn():
            return sparse_tensor.SparseTensor(indices=[[0]],
                                              values=constant_op.constant(
                                                  [0], dtype=dtypes.int32),
                                              dense_shape=[1])

        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))
Example #16
  def testPrefetchToSameDevice(self):
    host_dataset = dataset_ops.Dataset.range(10)
    device_dataset = host_dataset.apply(
        prefetching_ops.prefetch_to_device(
            "/job:localhost/replica:0/task:0/device:CPU:0"))

    with ops.device("/cpu:1"):
      iterator = dataset_ops.make_one_shot_iterator(device_dataset)
      next_element = iterator.get_next()

    self.assertTrue(structure.are_compatible(
        dataset_ops.get_structure(host_dataset),
        dataset_ops.get_structure(device_dataset)))

    self.assertEqual(dtypes.int64, next_element.dtype)
    self.assertEqual([], next_element.shape)

    worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
    with self.test_session(config=worker_config):
      for i in range(10):
        self.assertEqual(i, self.evaluate(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
Example #17
    def testNestedTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        def tf_value_fn():
            return {
                "a": constant_op.constant(37.0),
                "b":
                (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
            }

        tf_value = tf_value_fn()
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.from_tensors(tf_value))

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(iterator),
                                     expected_element_structure))
        self.assertEqual(expected_output_classes,
                         dataset_ops.get_legacy_output_classes(iterator))
        self.assertEqual(expected_output_types,
                         dataset_ops.get_legacy_output_types(iterator))
        self.assertEqual(expected_output_shapes,
                         dataset_ops.get_legacy_output_shapes(iterator))
Example #18
  def __init__(self, input_dataset, features, num_parallel_calls):
    self._input_dataset = input_dataset
    if not structure.are_compatible(
        input_dataset.element_spec,
        tensor_spec.TensorSpec([None], dtypes.string)):
      raise TypeError("Input dataset should be a dataset of vectors of strings")
    self._num_parallel_calls = num_parallel_calls
    # pylint: disable=protected-access
    self._features = parsing_ops._prepend_none_dimension(features)
    # sparse_keys and dense_keys come back sorted here.
    (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
     dense_shapes) = parsing_ops._features_to_raw_params(
         self._features, [
             parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
             parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
         ])
    # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
    (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
     dense_shape_as_shape) = parsing_ops._process_raw_parameters(
         None, dense_defaults, sparse_keys, sparse_types, dense_keys,
         dense_types, dense_shapes)
    # pylint: enable=protected-access
    self._sparse_keys = sparse_keys
    self._sparse_types = sparse_types
    self._dense_keys = dense_keys
    self._dense_defaults = dense_defaults_vec
    self._dense_shapes = dense_shapes
    self._dense_types = dense_types
    input_dataset_shape = dataset_ops.get_legacy_output_shapes(
        self._input_dataset)
    dense_output_shapes = [input_dataset_shape.concatenate(shape)
                           for shape in dense_shape_as_shape]
    sparse_output_shapes = [input_dataset_shape.concatenate([None])
                            for _ in range(len(sparse_keys))]

    output_shapes = dict(
        zip(self._dense_keys + self._sparse_keys,
            dense_output_shapes + sparse_output_shapes))
    output_types = dict(
        zip(self._dense_keys + self._sparse_keys,
            self._dense_types + self._sparse_types))
    output_classes = dict(
        zip(self._dense_keys + self._sparse_keys,
            [ops.Tensor for _ in range(len(self._dense_defaults))] +
            [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
            ]))
    self._element_spec = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    if compat.forward_compatible(2019, 8, 3):
      variant_tensor = (
          gen_experimental_dataset_ops.parse_example_dataset(
              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
              self._num_parallel_calls,
              self._dense_defaults,
              self._sparse_keys,
              self._dense_keys,
              self._sparse_types,
              self._dense_shapes,
              **self._flat_structure))
    else:
      variant_tensor = (
          gen_experimental_dataset_ops.experimental_parse_example_dataset(
              self._input_dataset._variant_tensor,  # pylint: disable=protected-access
              self._num_parallel_calls,
              self._dense_defaults,
              self._sparse_keys,
              self._dense_keys,
              self._sparse_types,
              self._dense_shapes,
              **self._flat_structure))
    super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
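
This dataset backs the public `tf.data.experimental.parse_example_dataset` transformation; a minimal usage sketch (the file name and feature spec are placeholders):

```python
import tensorflow as tf

features = {
    "label": tf.io.FixedLenFeature([], tf.int64),
    "tokens": tf.io.VarLenFeature(tf.string),
}
# The input must be a dataset of string vectors, hence the batch() call.
ds = tf.data.TFRecordDataset("data.tfrecord").batch(32)
ds = ds.apply(tf.data.experimental.parse_example_dataset(
    features, num_parallel_calls=4))
```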
Example #19
    def __init__(self, input_dataset, features, num_parallel_calls,
                 deterministic):
        self._input_dataset = input_dataset
        if not structure.are_compatible(
                input_dataset.element_spec,
                tensor_spec.TensorSpec([None], dtypes.string)):
            raise TypeError(
                "Input dataset should be a dataset of vectors of strings")
        self._num_parallel_calls = num_parallel_calls
        if deterministic is None:
            self._deterministic = "default"
        elif deterministic:
            self._deterministic = "true"
        else:
            self._deterministic = "false"
        # pylint: disable=protected-access
        self._features = parsing_ops._prepend_none_dimension(features)
        # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature
        params = parsing_ops._ParseOpParams.from_features(
            self._features, [
                parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
                parsing_ops.FixedLenFeature,
                parsing_ops.FixedLenSequenceFeature, parsing_ops.RaggedFeature
            ])
        # pylint: enable=protected-access
        self._sparse_keys = params.sparse_keys
        self._sparse_types = params.sparse_types
        self._ragged_keys = params.ragged_keys
        self._ragged_value_types = params.ragged_value_types
        self._ragged_split_types = params.ragged_split_types
        self._dense_keys = params.dense_keys
        self._dense_defaults = params.dense_defaults_vec
        self._dense_shapes = params.dense_shapes_as_proto
        self._dense_types = params.dense_types
        input_dataset_shape = dataset_ops.get_legacy_output_shapes(
            self._input_dataset)

        self._element_spec = {}

        for (key, value_type) in zip(params.sparse_keys, params.sparse_types):
            self._element_spec[key] = sparse_tensor.SparseTensorSpec(
                input_dataset_shape.concatenate([None]), value_type)

        for (key, value_type, dense_shape) in zip(params.dense_keys,
                                                  params.dense_types,
                                                  params.dense_shapes):
            self._element_spec[key] = tensor_spec.TensorSpec(
                input_dataset_shape.concatenate(dense_shape), value_type)

        for (key, value_type, splits_type) in zip(params.ragged_keys,
                                                  params.ragged_value_types,
                                                  params.ragged_split_types):
            self._element_spec[key] = ragged_tensor.RaggedTensorSpec(
                input_dataset_shape.concatenate([None]), value_type, 1,
                splits_type)

        variant_tensor = (
            gen_experimental_dataset_ops.parse_example_dataset_v2(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                self._num_parallel_calls,
                self._dense_defaults,
                self._sparse_keys,
                self._dense_keys,
                self._sparse_types,
                self._dense_shapes,
                deterministic=self._deterministic,
                ragged_keys=self._ragged_keys,
                ragged_value_types=self._ragged_value_types,
                ragged_split_types=self._ragged_split_types,
                **self._flat_structure))
        super(_ParseExampleDataset, self).__init__(input_dataset,
                                                   variant_tensor)