Beispiel #1
0
    def testNestedNestedStructure(self):
        """Round-trips a tuple-of-tuples structure through the tensor-list helpers.

        Flattens a nested tuple of tensors with `structure.to_tensor_list` and
        rebuilds it with both `structure.from_tensor_list` and
        `structure.from_compatible_tensor_list`, asserting at each step that
        the very same tensor objects come back (identity, not equality).
        """
        int64_spec = structure.TensorStructure(dtypes.int64, [])
        float32_spec = structure.TensorStructure(dtypes.float32, [])
        string_spec = structure.TensorStructure(dtypes.string, [])
        spec = (int64_spec, (float32_spec, string_spec))

        int64_t = constant_op.constant(37, dtype=dtypes.int64)
        float32_t = constant_op.constant(42.0)
        string_t = constant_op.constant("Foo")

        def check_identity(rebuilt):
            # The rebuilt value must have the same nesting as the input and
            # contain the original tensor objects.
            got_int64, (got_float32, got_string) = rebuilt
            self.assertIs(int64_t, got_int64)
            self.assertIs(float32_t, got_float32)
            self.assertIs(string_t, got_string)

        # Flattening yields the tensors in depth-first order.
        flat_tensors = structure.to_tensor_list(
            spec, (int64_t, (float32_t, string_t)))
        originals = [int64_t, float32_t, string_t]
        for want, got in zip(originals, flat_tensors):
            self.assertIs(want, got)

        check_identity(structure.from_tensor_list(spec, flat_tensors))
        check_identity(
            structure.from_compatible_tensor_list(spec, flat_tensors))
Beispiel #2
0
    def testNestedNestedStructure(self):
        """Round-trips tensors through a manually nested `NestedStructure`.

        Flattens with `_to_tensor_list` and rebuilds with `_from_tensor_list`
        and `_from_compatible_tensor_list`, asserting that the original tensor
        objects are preserved by identity.
        """
        # Although `Structure.from_value()` will not construct one, a nested
        # structure containing nested `NestedStructure` objects can occur if a
        # structure is constructed manually.
        int64_spec = structure.TensorStructure(dtypes.int64, [])
        inner_spec = structure.NestedStructure(
            (structure.TensorStructure(dtypes.float32, []),
             structure.TensorStructure(dtypes.string, [])))
        spec = structure.NestedStructure((int64_spec, inner_spec))

        int64_t = constant_op.constant(37, dtype=dtypes.int64)
        float32_t = constant_op.constant(42.0)
        string_t = constant_op.constant("Foo")

        def check_identity(rebuilt):
            # Same nesting as the input, same tensor objects inside.
            got_int64, (got_float32, got_string) = rebuilt
            self.assertIs(int64_t, got_int64)
            self.assertIs(float32_t, got_float32)
            self.assertIs(string_t, got_string)

        # pylint: disable=protected-access
        flat_tensors = spec._to_tensor_list((int64_t, (float32_t, string_t)))
        originals = [int64_t, float32_t, string_t]
        for want, got in zip(originals, flat_tensors):
            self.assertIs(want, got)

        check_identity(spec._from_tensor_list(flat_tensors))
        check_identity(spec._from_compatible_tensor_list(flat_tensors))
        # pylint: enable=protected-access
Beispiel #3
0
  def _make_window_size_func(self, window_size_func):
    """Wraps `window_size_func` and checks that it yields an int64 scalar.

    The wrapped function is stored on `self._window_size_func`.

    Args:
      window_size_func: Callable taking a key; its result is converted to a
        `tf.int64` tensor.

    Raises:
      ValueError: If the wrapped function's output structure is not compatible
        with a scalar `tf.int64` tensor.
    """

    def window_size_func_wrapper(key):
      window_size = window_size_func(key)
      return ops.convert_to_tensor(window_size, dtype=dtypes.int64)

    wrapped = dataset_ops.StructuredFunctionWrapper(
        window_size_func_wrapper,
        self._transformation_name(),
        input_structure=structure.TensorStructure(dtypes.int64, []))
    self._window_size_func = wrapped

    expected_output = structure.TensorStructure(dtypes.int64, [])
    if wrapped.output_structure.is_compatible_with(expected_output):
      return
    raise ValueError(
        "`window_size_func` must return a single tf.int64 scalar tensor.")
Beispiel #4
0
 def testFromNone(self):
   """Checks an empty optional built from a float32 scalar structure.

   Its value structure must match only the structure it was created from, it
   must report no value, and reading the value must raise
   `InvalidArgumentError`.
   """
   float_scalar = structure.TensorStructure(dtypes.float32, [])
   empty_opt = optional_ops.Optional.none_from_structure(float_scalar)
   opt_spec = empty_opt.value_structure

   self.assertTrue(opt_spec.is_compatible_with(float_scalar))

   # Same dtype, different shape.
   float_vector = structure.TensorStructure(dtypes.float32, [1])
   self.assertFalse(opt_spec.is_compatible_with(float_vector))

   # Same shape, different dtype.
   int_scalar = structure.TensorStructure(dtypes.int32, [])
   self.assertFalse(opt_spec.is_compatible_with(int_scalar))

   has_value = self.evaluate(empty_opt.has_value())
   self.assertFalse(has_value)
   with self.assertRaises(errors.InvalidArgumentError):
     self.evaluate(empty_opt.get_value())
Beispiel #5
0
    def testImportedFunctionsRegistered(self):
        """Imports a graph containing a dataset reduction into a wrapped function.

        A graph that reduces a dataset fed through a variant placeholder is
        serialized to a `GraphDef`, re-imported inside a wrapped function, and
        evaluated on a real dataset variant.
        """
        if test.is_built_with_gpu_support():
            self.skipTest(
                "Disabling this new test due to errors with cuda and rocm")

        with ops.Graph().as_default() as graph:
            variant_input = array_ops.placeholder(
                dtypes.variant, shape=[], name="foo")
            int32_scalar = structure.TensorStructure(dtypes.int32, [])
            placeholder_ds = dataset_ops.from_variant(
                variant_input, structure=int32_scalar)
            initial_sum = array_ops.zeros([], dtype=dtypes.int32)
            reduced = placeholder_ds.reduce(initial_sum, lambda p, q: p + q)

        graph_def = graph.as_graph_def()

        def fn_to_wrap(a):
            # Re-import the serialized graph, feeding `a` in place of the
            # placeholder and fetching the reduction result.
            imported = graph_def_importer.import_graph_def(
                graph_def,
                input_map={variant_input.name: a},
                return_elements=[reduced.name])
            return imported[0]

        variant_signature = [tensor_spec.TensorSpec((), dtypes.variant)]
        wrapped_fn = wrap_function.wrap_function(fn_to_wrap, variant_signature)

        real_ds = dataset_ops.Dataset.from_tensor_slices([10, 20])
        real_variant = dataset_ops.to_variant(real_ds)
        self.evaluate(wrapped_fn(real_variant))
Beispiel #6
0
    def testCopyToGPU(self):
        """Copies optionals created on the CPU to the GPU via `identity`.

        Skipped when no GPU is available. Both a populated and an empty
        optional are rebuilt on the GPU from the copied variant tensor and
        then inspected.
        """
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        with ops.device("/cpu:0"):
            cpu_components = (constant_op.constant(37.0),
                              constant_op.constant("Foo"),
                              constant_op.constant(42))
            cpu_full = optional_ops.Optional.from_value(cpu_components)
            cpu_empty = optional_ops.Optional.none_from_structure(
                structure.TensorStructure(dtypes.float32, []))

        with ops.device("/gpu:0"):
            # pylint: disable=protected-access
            gpu_full = optional_ops._OptionalImpl(
                array_ops.identity(cpu_full._variant_tensor),
                cpu_full.value_structure)
            gpu_empty = optional_ops._OptionalImpl(
                array_ops.identity(cpu_empty._variant_tensor),
                cpu_empty.value_structure)
            # pylint: enable=protected-access

            gpu_full_has_value = gpu_full.has_value()
            gpu_full_value = gpu_full.get_value()

            gpu_empty_has_value = gpu_empty.has_value()

        self.assertTrue(self.evaluate(gpu_full_has_value))
        self.assertEqual((37.0, b"Foo", 42), self.evaluate(gpu_full_value))
        self.assertFalse(self.evaluate(gpu_empty_has_value))
Beispiel #7
0
    def write(self, dataset):
        """Returns a `tf.Operation` to write a dataset to a file.

        Args:
          dataset: a `tf.data.Dataset` whose elements are to be written to a
            file.

        Returns:
          A `tf.Operation` that, when run, writes contents of `dataset` to a
          file.

        Raises:
          TypeError: If `dataset` is not a `tf.data.Dataset`, or if its
            element structure is not compatible with a scalar string tensor.
        """
        if not isinstance(dataset, dataset_ops.DatasetV2):
            raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

        dataset_spec = dataset_ops.get_structure(dataset)
        string_scalar = structure.TensorStructure(dtypes.string, [])
        if not dataset_spec.is_compatible_with(string_scalar):
            message = (
                "`dataset` must produce scalar `DT_STRING` tensors whereas it "
                "produces shape {0} and types {1}".format(
                    dataset_ops.get_legacy_output_shapes(dataset),
                    dataset_ops.get_legacy_output_types(dataset)))
            raise TypeError(message)

        # The op name depends on the forward-compatibility window; the
        # arguments are the same either way.
        if compat.forward_compatible(2019, 8, 3):
            to_tf_record = gen_experimental_dataset_ops.dataset_to_tf_record
        else:
            to_tf_record = (
                gen_experimental_dataset_ops.experimental_dataset_to_tf_record)
        return to_tf_record(
            dataset._variant_tensor,  # pylint: disable=protected-access
            self._filename,
            self._compression_type)
    def __init__(self, input_dataset, predicate):
        """See `take_while()` for details.

        Args:
          input_dataset: The dataset whose elements are passed to `predicate`.
          predicate: Callable wrapped for `input_dataset`; must return a
            scalar boolean tensor.

        Raises:
          ValueError: If the wrapped predicate's output structure is not
            compatible with a scalar `tf.bool` tensor.
        """
        self._input_dataset = input_dataset

        predicate_wrapper = dataset_ops.StructuredFunctionWrapper(
            predicate,
            "tf.data.experimental.take_while()",
            dataset=self._input_dataset)
        bool_scalar = structure_lib.TensorStructure(dtypes.bool, [])
        if not predicate_wrapper.output_structure.is_compatible_with(
                bool_scalar):
            raise ValueError(
                "`predicate` must return a scalar boolean tensor.")
        self._predicate = predicate_wrapper

        # The op name depends on the forward-compatibility window; the
        # arguments are the same either way.
        if compat.forward_compatible(2019, 8, 3):
            take_while_op = gen_experimental_dataset_ops.take_while_dataset
        else:
            take_while_op = (
                gen_experimental_dataset_ops.experimental_take_while_dataset)
        var_tensor = take_while_op(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            other_arguments=self._predicate.function.captured_inputs,
            predicate=self._predicate.function,
            **self._flat_structure)

        super(_TakeWhileDataset, self).__init__(input_dataset, var_tensor)
Beispiel #9
0
 def _make_key_func(self, key_func, input_dataset):
     """Wraps `key_func` for `input_dataset` and validates its output.

     The wrapped function is stored on `self._key_func`.

     Raises:
       ValueError: If the wrapped function's output structure is not
         compatible with a scalar `tf.int64` tensor.
     """
     wrapper = dataset_ops.StructuredFunctionWrapper(
         key_func, self._transformation_name(), dataset=input_dataset)
     self._key_func = wrapper

     int64_scalar = structure.TensorStructure(dtypes.int64, [])
     if wrapper.output_structure.is_compatible_with(int64_scalar):
         return
     raise ValueError(
         "`key_func` must return a single tf.int64 tensor. "
         "Got type=%s and shape=%s" %
         (wrapper.output_types, wrapper.output_shapes))
Beispiel #10
0
  def _make_key_func(self, key_func, input_dataset):
    """Wraps `key_func`, converting its result to a `tf.int64` tensor.

    The wrapped function is stored on `self._key_func`.

    Raises:
      ValueError: If the wrapped function's output structure is not compatible
        with a scalar `tf.int64` tensor.
    """

    def key_func_wrapper(*args):
      key = key_func(*args)
      return ops.convert_to_tensor(key, dtype=dtypes.int64)

    wrapper = dataset_ops.StructuredFunctionWrapper(
        key_func_wrapper, self._transformation_name(), dataset=input_dataset)
    self._key_func = wrapper

    int64_scalar = structure.TensorStructure(dtypes.int64, [])
    if wrapper.output_structure.is_compatible_with(int64_scalar):
      return
    raise ValueError(
        "`key_func` must return a single tf.int64 scalar tensor.")
Beispiel #11
0
  def __init__(self, client_resource, selected_fields, output_types,
               avro_schema, stream):
    """Builds the dataset from the `big_query_dataset` op.

    The element structure is a tuple with one scalar tensor structure per
    entry of `output_types`; all other arguments are forwarded to the op
    unchanged.
    """
    column_specs = [
        structure.TensorStructure(dtype, []) for dtype in output_types
    ]
    self._structure = structure.NestedStructure(tuple(column_specs))

    dataset_variant = _bigquery_so.big_query_dataset(
        client=client_resource,
        selected_fields=selected_fields,
        output_types=output_types,
        avro_schema=avro_schema,
        stream=stream)
    super(_BigQueryDataset, self).__init__(dataset_variant)
Beispiel #12
0
  def __init__(self, input_dataset, features, num_parallel_calls):
    """Creates a dataset that parses the string vectors of `input_dataset`.

    Args:
      input_dataset: The dataset to parse. Its element structure must be
        compatible with a rank-1 `tf.string` tensor of unknown length.
      features: Feature specification. It is passed through
        `parsing_ops._prepend_none_dimension` before being stored.
      num_parallel_calls: Stored unchanged on `self._num_parallel_calls`.

    Raises:
      TypeError: If `input_dataset` does not produce vectors of strings.
    """
    super(_ParseExampleDataset, self).__init__(input_dataset)
    self._input_dataset = input_dataset
    if not input_dataset._element_structure.is_compatible_with(  # pylint: disable=protected-access
        structure.TensorStructure(dtypes.string, [None])):
      raise TypeError("Input dataset should be a dataset of vectors of strings")
    self._num_parallel_calls = num_parallel_calls
    # pylint: disable=protected-access
    self._features = parsing_ops._prepend_none_dimension(features)
    # sparse_keys and dense_keys come back sorted here.
    (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
     dense_shapes) = parsing_ops._features_to_raw_params(
         self._features, [
             parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
             parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
         ])
    # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
    # Several names are rebound to the processed values; `dense_types` keeps
    # the value returned by `_features_to_raw_params` above.
    (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
     dense_shape_as_shape) = parsing_ops._process_raw_parameters(
         None, dense_defaults, sparse_keys, sparse_types, dense_keys,
         dense_types, dense_shapes)
    # pylint: enable=protected-access
    self._sparse_keys = sparse_keys
    self._sparse_types = sparse_types
    self._dense_keys = dense_keys
    self._dense_defaults = dense_defaults_vec
    self._dense_shapes = dense_shapes
    self._dense_types = dense_types
    # Each dense output shape is the input element shape followed by the
    # feature's own shape.
    dense_output_shapes = [
        self._input_dataset.output_shapes.concatenate(shape)
        for shape in dense_shape_as_shape
    ]
    # Each sparse output shape is the input element shape followed by one
    # unknown dimension.
    sparse_output_shapes = [
        self._input_dataset.output_shapes.concatenate([None])
        for _ in range(len(sparse_keys))
    ]

    # In all three legacy mappings below, dense keys come first and sparse
    # keys second, paired positionally with their shapes/types/classes.
    output_shapes = dict(
        zip(self._dense_keys + self._sparse_keys,
            dense_output_shapes + sparse_output_shapes))
    output_types = dict(
        zip(self._dense_keys + self._sparse_keys,
            self._dense_types + self._sparse_types))
    # NOTE(review): the number of `ops.Tensor` entries is taken from
    # `len(self._dense_defaults)`; this presumably equals the number of dense
    # keys -- confirm against `_process_raw_parameters`.
    output_classes = dict(
        zip(self._dense_keys + self._sparse_keys,
            [ops.Tensor for _ in range(len(self._dense_defaults))] +
            [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
            ]))
    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
Beispiel #13
0
 def _make_reduce_func(self, reduce_func, input_dataset):
   """Wraps `reduce_func` and derives this dataset's element structure.

   The wrapped function takes a scalar `tf.int64` tensor together with a
   dataset of `input_dataset`'s elements. It is stored on `self._reduce_func`,
   and the element structure of the dataset it returns is stored on
   `self._structure`.

   Raises:
     TypeError: If the wrapped function does not return a `Dataset`.
   """
   # pylint: disable=protected-access
   window_spec = dataset_ops.DatasetStructure(
       input_dataset._element_structure)
   key_spec = structure.TensorStructure(dtypes.int64, [])
   input_structure = structure.NestedStructure((key_spec, window_spec))

   wrapper = dataset_ops.StructuredFunctionWrapper(
       reduce_func,
       self._transformation_name(),
       input_structure=input_structure)
   self._reduce_func = wrapper

   result_spec = wrapper.output_structure
   if not isinstance(result_spec, dataset_ops.DatasetStructure):
     raise TypeError("`reduce_func` must return a `Dataset` object.")
   self._structure = result_spec._element_structure
Beispiel #14
0
    def __init__(self, driver_name, data_source_name, query, output_types):
        """Creates a `SqlDataset`.

        `SqlDataset` allows a user to read data from the result set of a SQL
        query. For example:

        ```python
        tf.compat.v1.enable_eager_execution()

        dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3",
                                                  "SELECT name, age FROM people",
                                                  (tf.string, tf.int32))
        # Prints the rows of the result set of the above query.
        for element in dataset:
          print(element)
        ```

        Args:
          driver_name: A 0-D `tf.string` tensor containing the database type.
            Currently, the only supported value is 'sqlite'.
          data_source_name: A 0-D `tf.string` tensor containing a connection
            string to connect to the database.
          query: A 0-D `tf.string` tensor containing the SQL query to execute.
          output_types: A tuple of `tf.DType` objects representing the types of
            the columns returned by `query`.
        """

        def as_string_tensor(value, name):
            # All three connection arguments are converted the same way.
            return ops.convert_to_tensor(
                value, dtype=dtypes.string, name=name)

        def scalar_spec(dtype):
            return structure.TensorStructure(dtype, [])

        self._driver_name = as_string_tensor(driver_name, "driver_name")
        self._data_source_name = as_string_tensor(
            data_source_name, "data_source_name")
        self._query = as_string_tensor(query, "query")

        # One scalar tensor structure per requested column type.
        self._structure = structure.NestedStructure(
            nest.map_structure(scalar_spec, output_types))

        # The op name depends on the forward-compatibility window; the
        # arguments are the same either way.
        if compat.forward_compatible(2019, 8, 3):
            sql_op = gen_experimental_dataset_ops.sql_dataset
        else:
            sql_op = gen_experimental_dataset_ops.experimental_sql_dataset
        variant_tensor = sql_op(
            self._driver_name, self._data_source_name, self._query,
            **self._flat_structure)
        super(SqlDatasetV2, self).__init__(variant_tensor)
Beispiel #15
0
def choose_from_datasets_v2(datasets, choice_dataset):
    """Creates a dataset that deterministically chooses elements from `datasets`.

    For example, given the following datasets:

    ```python
    datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
                tf.data.Dataset.from_tensors("bar").repeat(),
                tf.data.Dataset.from_tensors("baz").repeat()]

    # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
    choice_dataset = tf.data.Dataset.range(3).repeat(3)

    result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
    ```

    The elements of `result` will be:

    ```
    "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
    ```

    Args:
      datasets: A list of `tf.data.Dataset` objects with compatible structure.
      choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
        `0` and `len(datasets) - 1`.

    Returns:
      A dataset that interleaves elements from `datasets` according to the
      values of `choice_dataset`.

    Raises:
      TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
        type.
    """
    choice_spec = choice_dataset.element_spec
    int64_scalar = structure.TensorStructure(dtypes.int64, [])
    if structure.are_compatible(choice_spec, int64_scalar):
        return _DirectedInterleaveDataset(choice_dataset, datasets)
    raise TypeError("`choice_dataset` must be a dataset of scalar "
                    "`tf.int64` tensors.")
Beispiel #16
0
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for `optional_ops.Optional` and `optional_ops.OptionalStructure`."""

    def testFromValue(self):
        """An optional built from a scalar tensor has a value and returns it."""
        opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
        self.assertTrue(self.evaluate(opt.has_value()))
        self.assertEqual(37.0, self.evaluate(opt.get_value()))

    def testFromStructuredValue(self):
        """An optional built from a dict/tuple nest returns the same nest."""
        opt = optional_ops.Optional.from_value({
            "a":
            constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        })
        self.assertTrue(self.evaluate(opt.has_value()))
        self.assertEqual({
            "a": 37.0,
            "b": ([b"Foo"], b"Bar")
        }, self.evaluate(opt.get_value()))

    def testFromSparseTensor(self):
        """An optional built from sparse tensor values returns matching fields."""
        st_0 = sparse_tensor.SparseTensorValue(indices=np.array([[0]]),
                                               values=np.array([0],
                                                               dtype=np.int64),
                                               dense_shape=np.array([1]))
        st_1 = sparse_tensor.SparseTensorValue(
            indices=np.array([[0, 0], [1, 1]]),
            values=np.array([-1., 1.], dtype=np.float32),
            dense_shape=np.array([2, 2]))
        opt = optional_ops.Optional.from_value((st_0, st_1))
        self.assertTrue(self.evaluate(opt.has_value()))
        val_0, val_1 = opt.get_value()
        # Compare each sparse component field by field.
        for expected, actual in [(st_0, val_0), (st_1, val_1)]:
            self.assertAllEqual(expected.indices,
                                self.evaluate(actual.indices))
            self.assertAllEqual(expected.values, self.evaluate(actual.values))
            self.assertAllEqual(expected.dense_shape,
                                self.evaluate(actual.dense_shape))

    def testFromNone(self):
        """An empty optional keeps its structure, has no value, and fails on read."""
        value_structure = structure.TensorStructure(dtypes.float32, [])
        opt = optional_ops.Optional.none_from_structure(value_structure)
        self.assertTrue(
            opt.value_structure.is_compatible_with(value_structure))
        # Same dtype but a different shape is not compatible.
        self.assertFalse(
            opt.value_structure.is_compatible_with(
                structure.TensorStructure(dtypes.float32, [1])))
        # Same shape but a different dtype is not compatible.
        self.assertFalse(
            opt.value_structure.is_compatible_with(
                structure.TensorStructure(dtypes.int32, [])))
        self.assertFalse(self.evaluate(opt.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(opt.get_value())

    def testCopyToGPU(self):
        """Optionals created on the CPU can be copied to the GPU via `identity`.

        Skipped when no GPU is available.
        """
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        with ops.device("/cpu:0"):
            optional_with_value = optional_ops.Optional.from_value(
                (constant_op.constant(37.0), constant_op.constant("Foo"),
                 constant_op.constant(42)))
            optional_none = optional_ops.Optional.none_from_structure(
                structure.TensorStructure(dtypes.float32, []))

        with ops.device("/gpu:0"):
            # Rebuild each optional from an identity copy of its variant
            # tensor, keeping the original value structure.
            gpu_optional_with_value = optional_ops._OptionalImpl(
                array_ops.identity(optional_with_value._variant_tensor),
                optional_with_value.value_structure)
            gpu_optional_none = optional_ops._OptionalImpl(
                array_ops.identity(optional_none._variant_tensor),
                optional_none.value_structure)

            gpu_optional_with_value_has_value = gpu_optional_with_value.has_value(
            )
            gpu_optional_with_value_values = gpu_optional_with_value.get_value(
            )

            gpu_optional_none_has_value = gpu_optional_none.has_value()

        self.assertTrue(self.evaluate(gpu_optional_with_value_has_value))
        self.assertEqual((37.0, b"Foo", 42),
                         self.evaluate(gpu_optional_with_value_values))
        self.assertFalse(self.evaluate(gpu_optional_none_has_value))

    def _assertElementValueEqual(self, expected, actual):
        """Asserts `expected` and `actual` are equal, recursing into dicts.

        Dicts are compared key by key, `SparseTensorValue`s field by field,
        and anything else with `assertAllEqual`.
        """
        if isinstance(expected, dict):
            self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
            for k in expected.keys():
                self._assertElementValueEqual(expected[k], actual[k])
        elif isinstance(expected, sparse_tensor.SparseTensorValue):
            self.assertAllEqual(expected.indices, actual.indices)
            self.assertAllEqual(expected.values, actual.values)
            self.assertAllEqual(expected.dense_shape, actual.dense_shape)
        else:
            self.assertAllEqual(expected, actual)

    # Each case is (name, value factory, expected value structure). The values
    # are built by lambdas so that they are created when the test calls them.
    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), structure.SparseTensorStructure(
                dtypes.int32, [1])),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.TensorStructure(dtypes.string, [1]),
                   structure.TensorStructure(dtypes.string, []))
         })),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalStructure(
             structure.TensorStructure(dtypes.float32, []))),
    )
    def testSkipEagerOptionalStructure(self, tf_value_fn,
                                       expected_value_structure):
        """Checks the structure of an optional wrapping each parameterized value.

        Also checks that the optional survives a round trip through
        `_to_tensor_list()` and `_from_tensor_list()`.
        """
        tf_value = tf_value_fn()
        opt = optional_ops.Optional.from_value(tf_value)

        # Compatibility with the expected structure must hold in both
        # directions.
        self.assertTrue(
            expected_value_structure.is_compatible_with(opt.value_structure))
        self.assertTrue(
            opt.value_structure.is_compatible_with(expected_value_structure))

        opt_structure = structure.Structure.from_value(opt)
        self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
        self.assertTrue(opt_structure.is_compatible_with(opt_structure))
        self.assertTrue(
            opt_structure._value_structure.is_compatible_with(
                expected_value_structure))
        # The optional itself flattens to a single scalar variant tensor.
        self.assertEqual([dtypes.variant], opt_structure._flat_types)
        self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes)

        # All OptionalStructure objects are not compatible with a non-optional
        # value.
        non_optional_structure = structure.Structure.from_value(
            constant_op.constant(42.0))
        self.assertFalse(
            opt_structure.is_compatible_with(non_optional_structure))

        # Assert that the optional survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_opt = opt_structure._from_tensor_list(
            opt_structure._to_tensor_list(opt))
        if isinstance(tf_value, optional_ops.Optional):
            # The wrapped value is itself an optional, so unwrap twice.
            self.assertEqual(
                self.evaluate(tf_value.get_value()),
                self.evaluate(round_trip_opt.get_value().get_value()))
        else:
            self.assertEqual(self.evaluate(tf_value),
                             self.evaluate(round_trip_opt.get_value()))

    # Each case is (name, numpy element value, factory for a tensor value of
    # matching structure, whether the case is run when a GPU is available).
    @parameterized.named_parameters(
        ("Tensor", np.array([1, 2, 3], dtype=np.int32),
         lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
        ("SparseTensor",
         sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                         values=np.array([-1., 1.],
                                                         dtype=np.float32),
                                         dense_shape=[2, 2]),
         lambda: sparse_tensor.SparseTensor(indices=[[0, 1], [1, 0]],
                                            values=[37.0, 42.0],
                                            dense_shape=[2, 2]), False),
        ("Nest", {
            "a":
            np.array([1, 2, 3], dtype=np.int32),
            "b":
            sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                            values=np.array([-1., 1.],
                                                            dtype=np.float32),
                                            dense_shape=[2, 2])
        }, lambda: {
            "a":
            constant_op.constant([4, 5, 6], dtype=dtypes.int32),
            "b":
            sparse_tensor.SparseTensor(indices=[[0, 1], [1, 0]],
                                       values=[37.0, 42.0],
                                       dense_shape=[2, 2])
        }, False),
    )
    def testSkipEagerIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                               works_on_gpu):
        """Exercises `get_next_as_optional` before, during and after iteration.

        Cases with `works_on_gpu=False` are skipped when a GPU is available.
        """
        if not works_on_gpu and test.is_gpu_available():
            self.skipTest("Test case not yet supported on GPU.")
        ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
        iterator = ds.make_initializable_iterator()
        next_elem = iterator_ops.get_next_as_optional(iterator)
        self.assertIsInstance(next_elem, optional_ops.Optional)
        self.assertTrue(
            next_elem.value_structure.is_compatible_with(
                structure.Structure.from_value(tf_value_fn())))
        elem_has_value_t = next_elem.has_value()
        elem_value_t = next_elem.get_value()
        with self.cached_session() as sess:
            # Before initializing the iterator, evaluating the optional fails with
            # a FailedPreconditionError.
            with self.assertRaises(errors.FailedPreconditionError):
                sess.run(elem_has_value_t)
            with self.assertRaises(errors.FailedPreconditionError):
                sess.run(elem_value_t)

            # For each element of the dataset, assert that the optional evaluates to
            # the expected value.
            sess.run(iterator.initializer)
            for _ in range(3):
                elem_has_value, elem_value = sess.run(
                    [elem_has_value_t, elem_value_t])
                self.assertTrue(elem_has_value)
                self._assertElementValueEqual(np_value, elem_value)

            # After exhausting the iterator, `next_elem.has_value()` will evaluate to
            # false, and attempting to get the value will fail.
            for _ in range(2):
                self.assertFalse(sess.run(elem_has_value_t))
                with self.assertRaises(errors.InvalidArgumentError):
                    sess.run(elem_value_t)
Beispiel #17
0
 def element_spec(self):
     """Returns the element spec: one scalar `tf.string` tensor structure."""
     scalar_shape = []
     return structure.TensorStructure(dtypes.string, scalar_shape)
Beispiel #18
0
  def __init__(self,
               filenames,
               record_defaults,
               compression_type=None,
               buffer_size=None,
               header=False,
               field_delim=",",
               use_quote_delim=True,
               na_value="",
               select_cols=None):
    """Creates a `CsvDataset` by reading and decoding CSV files.

    The elements of this dataset correspond to records from the file(s).
    RFC 4180 format is expected for CSV files
    (https://tools.ietf.org/html/rfc4180).
    Note that leading and trailing spaces are allowed in int or float fields.

    For example, suppose we have a file 'my_file0.csv' with four CSV columns of
    different data types:
    ```
    abcdefg,4.28E10,5.55E6,12
    hijklmn,-5.3E14,,2
    ```

    We can construct a CsvDataset from it as follows:

    ```python
    tf.compat.v1.enable_eager_execution()

    dataset = tf.data.experimental.CsvDataset(
        "my_file*.csv",
        [tf.float32,  # Required field, use dtype or empty tensor
         tf.constant([0.0], dtype=tf.float32),  # Optional field, default to 0.0
         tf.int32,  # Required field, use dtype or empty tensor
         ],
        select_cols=[1,2,3]  # Only parse last three columns
    )
    ```

    The expected output of its iterations is:

    ```python
    for element in dataset:
      print(element)

    >> (4.28e10, 5.55e6, 12)
    >> (-5.3e14, 0.0, 2)
    ```

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_defaults: A list of default values for the CSV fields. Each item in
        the list is either a valid CSV `DType` (float32, float64, int32, int64,
        string), or a `Tensor` object with one of the above types. One per
        column of CSV data, with either a scalar `Tensor` default value for the
        column if it is optional, or `DType` or empty `Tensor` if required. If
        both this and `select_cols` are specified, these must have the same
        lengths, and `record_defaults` is assumed to be sorted in order of
        increasing column index.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
        compression.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer while reading files. Defaults to 4MB.
      header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
        have header line(s) that should be skipped when parsing. Defaults to
        `False`.
      field_delim: (Optional.) A `tf.string` scalar containing the delimiter
        character that separates fields in a record. Defaults to `","`.
      use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
        double quotation marks as regular characters inside of string fields
        (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
      na_value: (Optional.) A `tf.string` scalar indicating a value that will
        be treated as NA/NaN.
      select_cols: (Optional.) A sorted list of column indices to select from
        the input data. If specified, only this subset of columns will be
        parsed. Defaults to parsing all columns.
    """
    self._filenames = ops.convert_to_tensor(
        filenames, dtype=dtypes.string, name="filenames")
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    # A bare dtype (a required column) is replaced by an empty tensor of that
    # dtype so that every entry can be converted to a tensor below.
    record_defaults = [
        constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
        for x in record_defaults
    ]
    self._record_defaults = ops.convert_n_to_tensor(
        record_defaults, name="record_defaults")
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
    self._header = ops.convert_to_tensor(
        header, dtype=dtypes.bool, name="header")
    self._field_delim = ops.convert_to_tensor(
        field_delim, dtype=dtypes.string, name="field_delim")
    self._use_quote_delim = ops.convert_to_tensor(
        use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
    self._na_value = ops.convert_to_tensor(
        na_value, dtype=dtypes.string, name="na_value")
    self._select_cols = convert.optional_param_to_tensor(
        "select_cols",
        select_cols,
        argument_default=[],
        argument_dtype=dtypes.int64,
    )
    # The element structure is a tuple with one scalar tensor structure per
    # record default, using that default's dtype.
    self._structure = structure.NestedStructure(
        tuple(structure.TensorStructure(d.dtype, [])
              for d in self._record_defaults))
    variant_tensor = gen_experimental_dataset_ops.experimental_csv_dataset(
        filenames=self._filenames,
        record_defaults=self._record_defaults,
        buffer_size=self._buffer_size,
        header=self._header,
        output_shapes=self._structure._flat_shapes,  # pylint: disable=protected-access
        field_delim=self._field_delim,
        use_quote_delim=self._use_quote_delim,
        na_value=self._na_value,
        select_cols=self._select_cols,
        compression_type=self._compression_type)
    super(CsvDatasetV2, self).__init__(variant_tensor)
Beispiel #19
0
 def _element_structure(self):
     """Returns a pair of scalar `tf.string` tensor structures."""
     first = structure.TensorStructure(dtypes.string, [])
     second = structure.TensorStructure(dtypes.string, [])
     return (first, second)
Beispiel #20
0
 def _element_structure(self):
     """Returns a nested tuple of `self._num_outputs` scalar string structures.

     Every position of the tuple holds the same `TensorStructure` object.
     """
     scalar_string = structure.TensorStructure(dtypes.string, [])
     repeated = [scalar_string] * self._num_outputs
     return structure.NestedStructure(tuple(repeated))
Beispiel #21
0
 def _element_structure(self):
     """Returns the element structure: one scalar `tf.int64` tensor."""
     scalar_shape = []
     return structure.TensorStructure(dtypes.int64, scalar_shape)
Beispiel #22
0
 def consume_optional(opt_tensor):
     """Wraps `opt_tensor` as a float32-scalar optional and returns its value."""
     float_scalar = structure.TensorStructure(dtypes.float32, [])
     # pylint: disable=protected-access
     return optional_ops._OptionalImpl(opt_tensor, float_scalar).get_value()
Beispiel #23
0
class IteratorTest(test.TestCase, parameterized.TestCase):

  def testNoGradients(self):
    """Checks that no gradient is reported through a one-shot iterator."""
    source = constant_op.constant([1.])
    offset = constant_op.constant(0.)
    element = (
        dataset_ops.Dataset.from_tensor_slices(source)
        .map(lambda x: x + offset)
        .make_one_shot_iterator()
        .get_next())
    # Differentiate w.r.t. the sliced tensor, the captured tensor, and both.
    for wrt in (source, offset, [source, offset]):
      self.assertIsNone(gradients_impl.gradients(element, wrt)[0])

  def testCapturingStateInOneShotRaisesException(self):
    """Checks that a one-shot iterator rejects a map that captures a variable."""
    captured_var = variables.Variable(37.0, name="myvar")
    slices = dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0])
    stateful_dataset = slices.map(lambda x: x + captured_var)
    # The message must also mention the captured variable's name.
    expected_message = (
        r"`Dataset.make_one_shot_iterator\(\)` does not support "
        "datasets that capture stateful objects.+myvar")
    with self.assertRaisesRegexp(ValueError, expected_message):
      stateful_dataset.make_one_shot_iterator()

  def testOneShotIterator(self):
    """Checks shapes and squared values from a repeated one-shot iterator."""
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = dataset.map(_map_fn).repeat(14)
    next_element = dataset.make_one_shot_iterator().get_next()

    # Slicing drops the leading dimension of every component.
    expected_shapes = [c.shape[1:] for c in components]
    actual_shapes = [t.shape for t in next_element]
    self.assertEqual(expected_shapes, actual_shapes)

    with self.cached_session() as sess:
      for _ in range(14):
        for row in range(7):
          values = sess.run(next_element)
          for expected, actual in zip(components, values):
            self.assertAllEqual(expected[row]**2, actual)
      # After 14 passes over the 7 rows the iterator is exhausted.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testOneShotIteratorCaptureByValue(self):
    """Checks a one-shot iterator built from pre-converted tensor components."""
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    # Convert to tensors up front so the dataset is built from tensors rather
    # than from the numpy arrays themselves.
    tensor_components = tuple(ops.convert_to_tensor(c) for c in components)

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = dataset_ops.Dataset.from_tensor_slices(tensor_components)
    dataset = dataset.map(_map_fn).repeat(14)
    next_element = dataset.make_one_shot_iterator().get_next()

    expected_shapes = [c.shape[1:] for c in components]
    actual_shapes = [t.shape for t in next_element]
    self.assertEqual(expected_shapes, actual_shapes)

    with self.cached_session() as sess:
      for _ in range(14):
        for row in range(7):
          values = sess.run(next_element)
          for expected, actual in zip(components, values):
            self.assertAllEqual(expected[row]**2, actual)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testOneShotIteratorInsideContainer(self):
    """Runs one-shot iterators built under two differently named containers."""
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def within_container():
      """Builds the dataset and returns its one-shot get_next tensors."""

      def _map_fn(x, y, z):
        return math_ops.square(x), math_ops.square(y), math_ops.square(z)

      dataset = dataset_ops.Dataset.from_tensor_slices(components)
      dataset = dataset.map(_map_fn).repeat(14)
      return dataset.make_one_shot_iterator().get_next()

    server = server_lib.Server.create_local_server()

    # Each pass builds its iterator under its own container name so that the
    # two sessions do not share iterator resources. Using the same container
    # name for both sessions would make the checks below fail.
    for index in range(2):
      with session.Session(server.target) as sess:
        with ops.container("iteration%d" % index):
          next_element = within_container()

        for _ in range(14):
          for row in range(7):
            values = sess.run(next_element)
            for expected, actual in zip(components, values):
              self.assertAllEqual(expected[row]**2, actual)
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)

  def testOneShotIteratorNonBlocking(self):
    """Checks a one-shot iterator with one inter-op thread and eight callers."""
    squared = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x)
    next_element = squared.make_one_shot_iterator().get_next()

    # Restrict the session to a single inter-op thread to make sure the
    # one-shot iterator initializer does not deadlock.
    single_thread_config = config_pb2.ConfigProto(
        inter_op_parallelism_threads=1, use_per_session_threads=True)
    with session.Session(config=single_thread_config) as sess:
      self.assertAllEqual([1, 4, 9], sess.run(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

    # Now have several threads invoke the one-shot iterator concurrently.
    with session.Session(config=single_thread_config) as sess:
      outcomes = []

      def consumer_thread():
        # Record the element, or None when the iterator was already drained.
        try:
          outcome = sess.run(next_element)
        except errors.OutOfRangeError:
          outcome = None
        outcomes.append(outcome)

      num_threads = 8
      workers = []
      for _ in range(num_threads):
        workers.append(self.checkedThread(consumer_thread))
      for worker in workers:
        worker.start()
      for worker in workers:
        worker.join()

      # Exactly one thread gets the single element; all others see end-of-data.
      exhausted = [r for r in outcomes if r is None]
      produced = [r for r in outcomes if r is not None]
      self.assertEqual(num_threads, len(outcomes))
      self.assertEqual(num_threads - 1, len(exhausted))
      self.assertAllEqual([[1, 4, 9]], produced)

  def testOneShotIteratorInitializerFails(self):
    """Checks that a failing one-shot initializer errors on every get_next."""
    # 1.0 / 0.0 passed through check_numerics gives a dataset whose
    # initialization will always fail with the message "oops".
    bad_value = array_ops.check_numerics(
        constant_op.constant(1.0) / constant_op.constant(0.0), "oops")
    iterator = dataset_ops.Dataset.from_tensors(
        bad_value).make_one_shot_iterator()
    next_element = iterator.get_next()

    def _expect_oops(sess):
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
        sess.run(next_element)

    with self.cached_session() as sess:
      _expect_oops(sess)
      # Subsequent attempts to use the iterator must fail the same way.
      _expect_oops(sess)

    with self.cached_session() as sess:

      def consumer_thread():
        _expect_oops(sess)

      num_threads = 8
      workers = []
      for _ in range(num_threads):
        workers.append(self.checkedThread(consumer_thread))
      for worker in workers:
        worker.start()
      for worker in workers:
        worker.join()

  def testSimpleSharedResource(self):
    """Shares one named iterator across two sessions on the same server."""
    components = (np.array(1, dtype=np.int64),
                  np.array([1, 2, 3], dtype=np.int64),
                  np.array(37.0, dtype=np.float64))

    server = server_lib.Server.create_local_server()

    def _consume_and_check(sess, next_element):
      """Reads the single element, compares it, then expects end-of-data."""
      values = sess.run(next_element)
      for expected, actual in zip(components, values):
        self.assertAllEqual(expected, actual)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

    # Two non-overlapping sessions use the same iterator resource on the same
    # server; initializing the iterator in the first session must be visible
    # in the second one.
    with ops.Graph().as_default():
      dataset = dataset_ops.Dataset.from_tensors(components)
      dataset = dataset.map(lambda x, y, z: (x, y, z))
      iterator = dataset.make_initializable_iterator(
          shared_name="shared_iterator")
      init_op = iterator.initializer
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        sess.run(init_op)
        _consume_and_check(sess, get_next)

        # Re-initialize the iterator before leaving the first session.
        sess.run(init_op)

    with ops.Graph().as_default():
      # The iterator is re-declared from its structure only, so that none of
      # the first graph's functions are defined again in this graph under the
      # same names.
      iterator = iterator_ops.Iterator.from_structure(
          shared_name="shared_iterator",
          output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
          output_shapes=([], [3], []))
      get_next = iterator.get_next()

      with session.Session(server.target) as sess:
        # No initializer is run here: the second session relies on the
        # re-initialization done by the first one.
        _consume_and_check(sess, get_next)

  def testNotInitializedError(self):
    """Checks the error raised by get_next before the initializer has run."""
    components = (np.array(1), np.array([1, 2, 3]), np.array(37.0))
    dataset = dataset_ops.Dataset.from_tensors(components)
    next_element = dataset.make_initializable_iterator().get_next()

    with self.cached_session() as sess:
      # The initializer is deliberately never run.
      with self.assertRaisesRegexp(errors.FailedPreconditionError,
                                   "iterator has not been initialized"):
        sess.run(next_element)

  def testReinitializableIterator(self):
    """Switches one structure-defined iterator between two datasets."""
    dataset_3 = dataset_ops.Dataset.from_tensors(
        constant_op.constant([1, 2, 3]))
    dataset_4 = dataset_ops.Dataset.from_tensors(
        constant_op.constant([4, 5, 6, 7]))
    iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types,
                                                    [None])

    dataset_3_init_op = iterator.make_initializer(dataset_3)
    dataset_4_init_op = iterator.make_initializer(dataset_4)
    get_next = iterator.get_next()

    for dataset in (dataset_3, dataset_4):
      self.assertEqual(dataset.output_types, iterator.output_types)
    self.assertEqual([None], iterator.output_shapes.as_list())

    with self.cached_session() as sess:
      # Nothing has initialized the iterator yet.
      with self.assertRaises(errors.FailedPreconditionError):
        sess.run(get_next)

      # Initialize with the first dataset, switch to the second, then go back
      # to the first; each time the single element is read to exhaustion.
      schedule = (
          (dataset_3_init_op, [1, 2, 3]),
          (dataset_4_init_op, [4, 5, 6, 7]),
          (dataset_3_init_op, [1, 2, 3]),
      )
      for init_op, expected in schedule:
        sess.run(init_op)
        self.assertAllEqual(expected, sess.run(get_next))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  def testReinitializableIteratorStaticErrors(self):
    """Checks errors for mismatched nesting, dtypes and shapes."""
    two_types = (dtypes.int64, dtypes.float64)

    # Two dtypes paired with a single shape: the nestings do not match.
    with self.assertRaises(TypeError):
      iterator_ops.Iterator.from_structure(two_types, [None])

    # The remaining cases validate the dataset given to make_initializer.
    iterator = iterator_ops.Iterator.from_structure(two_types)

    # Dataset whose nesting is incompatible: each component is wrapped in its
    # own one-element tuple.
    with self.assertRaises(ValueError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors((
              (constant_op.constant([1, 2, 3], dtype=dtypes.int64),),
              (constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64),),
          )))

    # Dataset whose dtypes are incompatible (int32/float32).
    with self.assertRaises(TypeError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors((
              constant_op.constant([1, 2, 3], dtype=dtypes.int32),
              constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float32),
          )))

    # Dataset whose shapes are incompatible: the iterator declares `[]` for
    # the second component but the dataset supplies a 4-element vector.
    iterator = iterator_ops.Iterator.from_structure(two_types, ([None], []))
    with self.assertRaises(TypeError):
      iterator.make_initializer(
          dataset_ops.Dataset.from_tensors((
              constant_op.constant([1, 2, 3], dtype=dtypes.int64),
              constant_op.constant([4., 5., 6., 7.], dtype=dtypes.float64),
          )))

  def testIteratorStringHandle(self):
    """Alternates a feedable iterator between two one-shot iterator handles."""
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_3.make_one_shot_iterator()
    iterator_4 = dataset_4.make_one_shot_iterator()

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
    next_element = feedable_iterator.get_next()

    for dataset in (dataset_3, dataset_4):
      self.assertEqual(dataset.output_types, feedable_iterator.output_types)
    self.assertEqual([], feedable_iterator.output_shapes)

    with self.cached_session() as sess:
      handle_3 = sess.run(iterator_3.string_handle())
      handle_4 = sess.run(iterator_4.string_handle())

      def _next_from(handle):
        return sess.run(next_element, feed_dict={handle_placeholder: handle})

      # Interleave the two handles; each keeps its own position.
      expected_sequence = (
          (handle_4, 10), (handle_3, 1),
          (handle_4, 20), (handle_3, 2),
          (handle_4, 30), (handle_3, 3),
          (handle_4, 40),
      )
      for handle, expected in expected_sequence:
        self.assertEqual(expected, _next_from(handle))

      # Both underlying iterators have now been read to the end.
      for handle in (handle_3, handle_4):
        with self.assertRaises(errors.OutOfRangeError):
          _next_from(handle)

  def testIteratorStringHandleFuture(self):
    """Runs the string-handle interleaving under a forward-compat horizon."""
    with forward_compat.forward_compatibility_horizon(2018, 8, 4):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_4 = dataset_4.make_one_shot_iterator()

      handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      feedable_iterator = iterator_ops.Iterator.from_string_handle(
          handle_placeholder, dataset_3.output_types, dataset_3.output_shapes)
      next_element = feedable_iterator.get_next()

      for dataset in (dataset_3, dataset_4):
        self.assertEqual(dataset.output_types, feedable_iterator.output_types)
      self.assertEqual([], feedable_iterator.output_shapes)

      with self.cached_session() as sess:
        handle_3 = sess.run(iterator_3.string_handle())
        handle_4 = sess.run(iterator_4.string_handle())

        def _next_from(handle):
          return sess.run(
              next_element, feed_dict={handle_placeholder: handle})

        # Interleave the two handles; each keeps its own position.
        expected_sequence = (
            (handle_4, 10), (handle_3, 1),
            (handle_4, 20), (handle_3, 2),
            (handle_4, 30), (handle_3, 3),
            (handle_4, 40),
        )
        for handle, expected in expected_sequence:
          self.assertEqual(expected, _next_from(handle))

        # Both underlying iterators have now been read to the end.
        for handle in (handle_3, handle_4):
          with self.assertRaises(errors.OutOfRangeError):
            _next_from(handle)

  def testIteratorStringHandleReuseTensorObject(self):
    """Checks that the default string handle is reused and named ones are not."""
    dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    one_shot_iterator = dataset.make_one_shot_iterator()
    initializable_iterator = dataset.make_initializable_iterator()
    structure_iterator = iterator_ops.Iterator.from_structure(
        dataset.output_types)

    graph = ops.get_default_graph()
    ops_before = len(graph.get_operations())

    # Asking twice for the default handle must give the very same object.
    for iterator in (one_shot_iterator, initializable_iterator,
                     structure_iterator):
      self.assertIs(iterator.string_handle(), iterator.string_handle())

    # Fetching the (default) string handle must not have added any ops.
    self.assertEqual(ops_before, len(ops.get_default_graph().get_operations()))

    # An explicit name creates a new op.
    named_handle = one_shot_iterator.string_handle(name="foo")
    self.assertEqual("foo", named_handle.op.name)
    self.assertIsNot(one_shot_iterator.string_handle(), named_handle)

    # Asking for the same name again creates yet another op, named "foo_1".
    second_named_handle = one_shot_iterator.string_handle(name="foo")
    self.assertEqual("foo_1", second_named_handle.op.name)
    self.assertIsNot(named_handle, second_named_handle)

  def testIteratorStringHandleError(self):
    """Checks that a fed handle is validated against the declared structure.

    Three feedable iterators share one handle placeholder and declare int32
    elements with shape `[]`, shape `[None]`, and no shape at all. Feeding the
    handle of an int32 scalar iterator works for the first and third; the
    `[None]` iterator must raise `InvalidArgumentError` for both the int32
    scalar handle and the float vector handle.
    """
    dataset_int_scalar = (
        dataset_ops.Dataset.from_tensor_slices([1, 2, 3]).repeat())
    dataset_float_vector = (dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0]))

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])

    feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32, [])
    feedable_int_vector = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32, [None])
    feedable_int_any = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dtypes.int32)

    with self.cached_session() as sess:
      handle_int_scalar = sess.run(
          dataset_int_scalar.make_one_shot_iterator().string_handle())
      handle_float_vector = sess.run(
          dataset_float_vector.make_one_shot_iterator().string_handle())

      self.assertEqual(1,
                       sess.run(
                           feedable_int_scalar.get_next(),
                           feed_dict={handle_placeholder: handle_int_scalar}))

      # Same underlying iterator, so the next value continues from 1.
      self.assertEqual(2,
                       sess.run(
                           feedable_int_any.get_next(),
                           feed_dict={handle_placeholder: handle_int_scalar}))

      # The run itself is expected to raise, so its result is not printed.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(
            feedable_int_vector.get_next(),
            feed_dict={handle_placeholder: handle_int_scalar})

      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(
            feedable_int_vector.get_next(),
            feed_dict={handle_placeholder: handle_float_vector})

  def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
    """Reads an iterator placed on cpu:1 through remote_call in one session."""
    resource_device = "/job:localhost/replica:0/task:0/cpu:1"
    other_device = "/job:localhost/replica:0/task:0/cpu:2"

    # Three CPU devices: cpu:0 hosts the call, cpu:1 the iterator, and cpu:2
    # serves as a target that does not hold the resource.
    worker_config = config_pb2.ConfigProto()
    worker_config.device_count["CPU"] = 3

    with ops.device(resource_device):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_3_handle = iterator_3.string_handle()

    @function.Defun(dtypes.string)
    def _remote_fn(h):
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          h, dataset_3.output_types, dataset_3.output_shapes)
      return remote_iterator.get_next()

    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with self.session(config=worker_config) as sess:

      def _call_on(device):
        return sess.run(remote_op, feed_dict={target_placeholder: device})

      self.assertEqual(_call_on(resource_device), [1])
      # Fails when target is cpu:2 where the resource is not located.
      with self.assertRaises(errors.InvalidArgumentError):
        _call_on(other_device)
      # The failed call did not consume an element.
      self.assertEqual(_call_on(resource_device), [2])
      self.assertEqual(_call_on(resource_device), [3])
      with self.assertRaises(errors.OutOfRangeError):
        _call_on(resource_device)

  def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
    """Reads per-worker iterators from a client job through remote_call.

    Three local servers are arranged as a two-task "worker" job and a
    one-task "client" job. Each worker device holds a one-shot iterator that
    yields that device's own name; the client zips device names with iterator
    handles and maps them through `remote_call`, expecting to get the device
    names back in order.
    """
    s1 = server_lib.Server.create_local_server()
    s2 = server_lib.Server.create_local_server()
    s3 = server_lib.Server.create_local_server()

    # Describe the cluster: s1/s2 are worker tasks 0/1 and s3 is the client.
    # Each target has its leading len("grpc://") characters sliced off.
    cluster_def = cluster_pb2.ClusterDef()
    workers = cluster_def.job.add()
    workers.name = "worker"
    workers.tasks[0] = s1.target[len("grpc://"):]
    workers.tasks[1] = s2.target[len("grpc://"):]
    client = cluster_def.job.add()
    client.name = "client"
    client.tasks[0] = s3.target[len("grpc://"):]
    config = config_pb2.ConfigProto(cluster_def=cluster_def)

    worker_devices = [
        "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2)
    ]
    # One iterator per worker device; its only element is the device name.
    itr_handles = []
    for device in worker_devices:
      with ops.device(device):
        src = dataset_ops.Dataset.from_tensor_slices([device])
        itr = src.make_one_shot_iterator()
        itr_handles.append(itr.string_handle())

    targets = dataset_ops.Dataset.from_tensor_slices(worker_devices)
    handles = dataset_ops.Dataset.from_tensor_slices(itr_handles)

    @function.Defun(dtypes.string)
    def loading_func(h):
      # NOTE(review): `itr` is a closure variable resolved when this body
      # runs. It is bound by the loop above and rebound under "/job:client"
      # below, so statement order matters here -- confirm which binding the
      # Defun sees before reordering anything.
      remote_itr = iterator_ops.Iterator.from_string_handle(
          h, itr.output_types, itr.output_shapes)
      return remote_itr.get_next()

    def map_fn(target, handle):
      return functional_ops.remote_call(
          args=[handle], Tout=[dtypes.string], f=loading_func, target=target)

    with ops.device("/job:client"):
      client_dataset = dataset_ops.Dataset.zip((targets, handles)).map(map_fn)
      itr = client_dataset.make_initializable_iterator()
      n = itr.get_next()

    with session.Session(s3.target, config=config) as sess:
      sess.run(itr.initializer)
      expected_values = worker_devices
      for expected in expected_values:
        # Each element comes back as a 1-tuple holding the name as bytes.
        self.assertEqual((compat.as_bytes(expected),), sess.run(n))

      with self.assertRaises(errors.OutOfRangeError):
        sess.run(n)

  def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
    """Calls a CPU-placed iterator from a remote_call op placed on the GPU."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    cpu_device = "/job:localhost/replica:0/task:0/cpu:0"

    with ops.device(cpu_device):
      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
      iterator_3 = dataset_3.make_one_shot_iterator()
      iterator_3_handle = iterator_3.string_handle()

    def _encode_raw(byte_array):
      return bytes(bytearray(byte_array))

    @function.Defun(dtypes.uint8)
    def _remote_fn(h):
      # The handle is passed in as uint8 values; turn it back into a string
      # before rebuilding the iterator from it.
      handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
      remote_iterator = iterator_ops.Iterator.from_string_handle(
          handle, dataset_3.output_types, dataset_3.output_shapes)
      return remote_iterator.get_next()

    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
      # The string handle is decoded to uint8 before being passed to the call.
      iterator_3_handle_uint8 = parsing_ops.decode_raw(
          bytes=iterator_3_handle, out_type=dtypes.uint8)
      remote_op = functional_ops.remote_call(
          args=[iterator_3_handle_uint8],
          Tout=[dtypes.int32],
          f=_remote_fn,
          target=target_placeholder)

    with self.cached_session() as sess:

      def _call_on_cpu():
        return sess.run(remote_op, feed_dict={target_placeholder: cpu_device})

      for expected in ([1], [2], [3]):
        self.assertEqual(_call_on_cpu(), expected)
      with self.assertRaises(errors.OutOfRangeError):
        _call_on_cpu()

  def testIncorrectIteratorRestore(self):
    """Checks that restoring saved state into a mismatched iterator fails."""

    def _state_path():
      return os.path.join(self.get_temp_dir(), "iterator")

    def _make_ops(iterator):
      """Returns (init, get_next, save, restore) ops for `iterator`."""
      init_op = iterator.initializer
      get_next = iterator.get_next()
      resource = iterator._iterator_resource  # pylint: disable=protected-access
      # Save: serialize the iterator state and write it to the state file.
      saved_state = gen_dataset_ops.serialize_iterator(resource)
      save_op = io_ops.write_file(
          _state_path(), parsing_ops.serialize_tensor(saved_state))
      # Restore: read the file, parse it as a variant, and load it back.
      loaded_state = parsing_ops.parse_tensor(
          io_ops.read_file(_state_path()), dtypes.variant)
      restore_op = gen_dataset_ops.deserialize_iterator(resource, loaded_state)
      return init_op, get_next, save_op, restore_op

    def _range_iterator():
      return dataset_ops.Dataset.range(1, 10).make_initializable_iterator()

    def _reader_iterator():
      # The file "test" does not exist, which does not matter for this test.
      return readers.FixedLengthRecordDataset(
          ["test"], 1, 0, 0).make_initializable_iterator()

    # Save the state of an iterator over a RangeDataset.
    with ops.Graph().as_default() as g:
      init_op, _, save_op, _ = _make_ops(_range_iterator())
      with self.session(graph=g) as sess:
        sess.run(init_op)
        sess.run(save_op)

    # Try to restore that state into an IteratorResource of incompatible
    # type. An iterator of RangeDataset has output type int64, while an
    # iterator of FixedLengthRecordDataset has output type string, so an
    # InvalidArgumentError should be raised by
    # IteratorResource::set_iterator.
    with ops.Graph().as_default() as g:
      _, _, _, restore_op = _make_ops(_reader_iterator())
      with self.session(graph=g) as sess:
        with self.assertRaises(errors.InvalidArgumentError):
          sess.run(restore_op)

  def testRepeatedGetNextWarning(self):
    """Checks that calling get_next() past the threshold emits warnings.

    `get_next()` is called 100 times on one iterator; every call beyond
    `iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD` must record one warning
    containing `iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE`.
    """
    iterator = dataset_ops.Dataset.range(10).make_one_shot_iterator()
    with warnings.catch_warnings(record=True) as w:
      # Set the filter inside the context manager so that it is restored on
      # exit instead of changing the process-wide filters for later tests.
      warnings.simplefilter("always")
      for _ in range(100):
        iterator.get_next()
    self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD, len(w))
    for warning in w:
      self.assertIn(
          iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE, str(warning.message))

  def testEagerIteratorAsync(self):
    """Iterates a range dataset eagerly in ASYNC mode and checks each value."""
    with context.eager_mode(), context.execution_mode(context.ASYNC):
      # Element i of Dataset.range(10) is expected to equal i.
      for expected, element in enumerate(dataset_ops.Dataset.range(10)):
        self.assertEqual(expected, element.numpy())

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, []),
       ops.Tensor, dtypes.float32, []),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1]),
       sparse_tensor.SparseTensor, dtypes.int32, [1]),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))}),
       {"a": ops.Tensor, "b": (ops.Tensor, ops.Tensor)},
       {"a": dtypes.float32, "b": (dtypes.string, dtypes.string)},
       {"a": [], "b": ([1], [])}),
  )
  def testIteratorStructure(self, tf_value_fn, expected_element_structure,
                            expected_output_classes, expected_output_types,
                            expected_output_shapes):
    """Checks the structure and output properties of a one-shot iterator.

    Args:
      tf_value_fn: Zero-argument callable returning the value that the
        dataset is built from; it is called inside the test body.
      expected_element_structure: Structure that must be compatible with the
        iterator's `_element_structure` in both directions.
      expected_output_classes: Expected `iterator.output_classes`.
      expected_output_types: Expected `iterator.output_types`.
      expected_output_shapes: Expected `iterator.output_shapes`.
    """
    tf_value = tf_value_fn()
    iterator = dataset_ops.Dataset.from_tensors(
        tf_value).make_one_shot_iterator()

    # Compatibility is checked in both directions.
    self.assertTrue(expected_element_structure.is_compatible_with(
        iterator._element_structure))
    self.assertTrue(iterator._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual(expected_output_classes, iterator.output_classes)
    self.assertEqual(expected_output_types, iterator.output_types)
    self.assertEqual(expected_output_shapes, iterator.output_shapes)
Beispiel #24
0
 def _make_init_func(self, init_func):
     """Wrap `init_func` in a `StructuredFunctionWrapper` and store it.

     The wrapper is created with this object's transformation name and an
     int64 `TensorStructure` of shape `[]` as its input structure, and is
     saved on `self._init_func`.

     Args:
       init_func: The callable to wrap.
     """
     transformation = self._transformation_name()
     input_spec = structure.TensorStructure(dtypes.int64, [])
     self._init_func = dataset_ops.StructuredFunctionWrapper(
         init_func, transformation, input_structure=input_spec)
Beispiel #25
0
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):

    # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
    # will be executed before the (eager- or graph-mode) test environment has been
    # set up.
    # pylint: disable=g-long-lambda,protected-access
    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), structure.TensorStructure,
         [dtypes.float32], [[]]),
        (lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         structure.TensorArrayStructure, [dtypes.variant], [None, 3]),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         structure.SparseTensorStructure, [dtypes.variant], [None]),
        (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
         structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]
                                                                       ]),
        (lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, structure.NestedStructure,
         [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None]))
    def testFlatStructure(self, value_fn, expected_structure, expected_types,
                          expected_shapes):
        """Checks the class, flat types and flat shapes inferred for a value.

        Args:
          value_fn: Zero-argument callable returning the value to inspect.
          expected_structure: Class the inferred structure must be an
            instance of.
          expected_types: Expected value of the structure's `_flat_types`.
          expected_shapes: Shapes paired positionally with `_flat_shapes`;
            each pair must be compatible in both directions. Because `zip`
            stops at the shorter sequence, surplus entries on either side
            are not checked.
        """
        inferred = structure.Structure.from_value(value_fn())
        self.assertIsInstance(inferred, expected_structure)
        self.assertEqual(expected_types, inferred._flat_types)

        shape_pairs = zip(expected_shapes, inferred._flat_shapes)
        for want, got in shape_pairs:
            self.assertTrue(got.is_compatible_with(want))
            want_shape = tensor_shape.as_shape(want)
            self.assertTrue(want_shape.is_compatible_with(got))

    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), lambda: [
            constant_op.constant(38.0),
            array_ops.placeholder(dtypes.float32),
            variables.Variable(100.0), 42.0,
            np.array(42.0, dtype=np.float32)
        ],
         lambda: [constant_op.constant([1.0, 2.0]),
                  constant_op.constant(37)]),
        (lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0), lambda: [
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(3, ), size=0),
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(3, ), size=10)
            ], lambda: [
                tensor_array_ops.TensorArray(
                    dtype=dtypes.int32, element_shape=(3, ), size=0),
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(), size=0)
            ]),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [
                sparse_tensor.SparseTensor(indices=[[1, 1], [3, 4]],
                                           values=[10, -1],
                                           dense_shape=[4, 5]),
                sparse_tensor.SparseTensorValue(indices=[[1, 1], [3, 4]],
                                                values=[10, -1],
                                                dense_shape=[4, 5]),
                array_ops.sparse_placeholder(dtype=dtypes.int32),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None])
            ], lambda: [
                constant_op.constant(37, shape=[4, 5]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None, None]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
            ]),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6])
        }], lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6, 7])
        }, {
            "a": constant_op.constant(15),
            "b": constant_op.constant([4, 5, 6])
        }, {
            "a":
            constant_op.constant(15),
            "b":
            sparse_tensor.SparseTensor(
                indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
        }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
    )
    @test_util.run_deprecated_v1
    def testIsCompatibleWithStructure(self, original_value_fn,
                                      compatible_values_fn,
                                      incompatible_values_fn):
        """Checks `is_compatible_with` against accepted and rejected values.

        Args:
          original_value_fn: Zero-argument callable returning the value whose
            structure is the reference.
          compatible_values_fn: Zero-argument callable returning values whose
            structures the reference must accept.
          incompatible_values_fn: Zero-argument callable returning values
            whose structures the reference must reject.
        """
        reference_value = original_value_fn()
        accepted = compatible_values_fn()
        rejected = incompatible_values_fn()
        reference = structure.Structure.from_value(reference_value)

        def accepts(candidate):
            candidate_structure = structure.Structure.from_value(candidate)
            return reference.is_compatible_with(candidate_structure)

        for candidate in accepted:
            self.assertTrue(accepts(candidate))
        for candidate in rejected:
            self.assertFalse(accepts(candidate))

    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), ),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), ),
        # The trailing comma makes this case a 1-tuple like its siblings
        # rather than a bare parenthesized lambda.
        (lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(), size=1).write(0, 7), ),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, ),
        (lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, ),
    )
    def testRoundTripConversion(self, value_fn):
        """Checks that flattening and restoring a value preserves it.

        The value is converted with `_to_tensor_list`, restored with
        `_from_tensor_list`, and the evaluated result is compared component
        by component with the evaluated original.

        Args:
          value_fn: Zero-argument callable returning the value to round-trip.
        """
        value = value_fn()
        s = structure.Structure.from_value(value)

        def maybe_stack_ta(v):
            # A TensorArray is stacked into a tensor before evaluation;
            # every other value is evaluated as-is.
            if isinstance(v, tensor_array_ops.TensorArray):
                return v.stack()
            return v

        before = self.evaluate(maybe_stack_ta(value))
        after = self.evaluate(
            maybe_stack_ta(s._from_tensor_list(s._to_tensor_list(value))))

        flat_before = nest.flatten(before)
        flat_after = nest.flatten(after)
        # `zip` stops at the shorter input, so without this check a round
        # trip that dropped components would still pass.
        self.assertEqual(len(flat_before), len(flat_after))
        for b, a in zip(flat_before, flat_after):
            if isinstance(b, sparse_tensor.SparseTensorValue):
                self.assertAllEqual(b.indices, a.indices)
                self.assertAllEqual(b.values, a.values)
                self.assertAllEqual(b.dense_shape, a.dense_shape)
            else:
                self.assertAllEqual(b, a)

    # pylint: enable=g-long-lambda

    def testIncompatibleStructure(self):
        """Checks that mismatched structures reject each other's values.

        Three values are defined: a dense tensor, a sparse tensor and a dict
        of tensors. For each structure the test asserts that:
        1. Flattening one of the other values with it fails.
        2. Restructuring one of the other flattened values with it fails.
        """
        from_value = structure.Structure.from_value

        dense_value = constant_op.constant(42.0)
        dense_struct = from_value(dense_value)
        dense_flat = dense_struct._to_tensor_list(dense_value)

        sparse_value = sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])
        sparse_struct = from_value(sparse_value)
        sparse_flat = sparse_struct._to_tensor_list(sparse_value)

        dict_value = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3]),
        }
        dict_struct = from_value(dict_value)
        dict_flat = dict_struct._to_tensor_list(dict_value)

        # (structure, value to flatten, exception type, message regexp)
        flatten_failures = (
            (dense_struct, sparse_value, ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*float32.* and shape \(\)"),
            (dense_struct, dict_value, ValueError,
             r"Value \{.*\} is not convertible to a tensor with "
             r"dtype.*float32.* and shape \(\)"),
            (sparse_struct, dense_value, TypeError,
             "Input must be a SparseTensor"),
            (sparse_struct, dict_value, TypeError,
             "Input must be a SparseTensor"),
            (dict_struct, dense_value, ValueError,
             "Tensor.* not compatible with the nested structure "
             ".*TensorStructure.*TensorStructure"),
            (dict_struct, sparse_value, ValueError,
             "SparseTensor.* not compatible with the nested structure "
             ".*TensorStructure.*TensorStructure"),
        )
        for target, bad_value, error_type, pattern in flatten_failures:
            with self.assertRaisesRegexp(error_type, pattern):
                target._to_tensor_list(bad_value)

        # (structure, flat list to restructure, exception type, regexp)
        restructure_failures = (
            (dense_struct, sparse_flat, ValueError,
             r"Cannot convert.*with dtype.*float32.* and shape \(\)"),
            (dense_struct, dict_flat, ValueError,
             "TensorStructure corresponds to a single tf.Tensor."),
            (sparse_struct, dense_flat, ValueError,
             "SparseTensorStructure corresponds to a single tf.variant "
             "vector of length 3."),
            (sparse_struct, dict_flat, ValueError,
             "SparseTensorStructure corresponds to a single tf.variant "
             "vector of length 3."),
            (dict_struct, dense_flat, ValueError,
             "Expected 2 flat values in NestedStructure but got 1."),
            (dict_struct, sparse_flat, ValueError,
             "Expected 2 flat values in NestedStructure but got 1."),
        )
        for target, bad_flat, error_type, pattern in restructure_failures:
            with self.assertRaisesRegexp(error_type, pattern):
                target._from_tensor_list(bad_flat)

    def testIncompatibleNestedStructure(self):
        """Checks that mismatched nested structures reject each other.

        Three dict values are defined. For each structure the test asserts
        that:
        1. Flattening one of the other values with it fails.
        2. Restructuring one of the other flattened values with it fails.
        """
        from_value = structure.Structure.from_value

        dense_dict = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3]),
        }
        dense_struct = from_value(dense_dict)
        dense_flat = dense_struct._to_tensor_list(dense_dict)

        # Same nesting as `dense_dict`, but "b" holds a different class.
        sparse_dict = {
            "a": constant_op.constant(37.0),
            "b": sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
        }
        sparse_struct = from_value(sparse_dict)
        sparse_flat = sparse_struct._to_tensor_list(sparse_dict)

        # Nesting differs from both dicts above: "b" is a pair.
        pair_dict = {
            "a": constant_op.constant(37.0),
            "b": (sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(
                      indices=[[3, 4]], values=[-1], dense_shape=[4, 5])),
        }
        pair_struct = from_value(pair_dict)
        pair_flat = pair_struct._to_tensor_list(pair_dict)

        # (structure, value to flatten, message regexp); all raise ValueError.
        #
        # NOTE(mrry): The repr of the dictionaries is not sorted, so the
        # regexps for `pair_struct` accept "a" before or after "b". A
        # deterministic repr for these error messages might be worth adding.
        flatten_failures = (
            (dense_struct, sparse_dict,
             "SparseTensor.* not compatible with the nested structure "
             ".*TensorStructure"),
            (dense_struct, pair_dict,
             "SparseTensor.*SparseTensor.* not compatible with the "
             "nested structure .*TensorStructure"),
            (sparse_struct, dense_dict,
             "Tensor.* not compatible with the nested structure "
             ".*SparseTensorStructure"),
            # NOTE(review): this repeats the (dense_struct, pair_dict) case
            # above; presumably (sparse_struct, pair_dict) was intended.
            # TODO confirm the expected message before changing it.
            (dense_struct, pair_dict,
             "SparseTensor.*SparseTensor.* not compatible with the "
             "nested structure .*TensorStructure"),
            (pair_struct, dense_dict,
             "Tensor.*Tensor.* not compatible with the nested structure "
             ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
             "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"),
            (pair_struct, sparse_dict,
             "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
             "not compatible with the nested structure .*"
             "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
             "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"),
        )
        for target, bad_value, pattern in flatten_failures:
            with self.assertRaisesRegexp(ValueError, pattern):
                target._to_tensor_list(bad_value)

        # (structure, flat list to restructure, message regexp); all raise
        # ValueError.
        restructure_failures = (
            (dense_struct, sparse_flat,
             r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"),
            (dense_struct, pair_flat,
             "Expected 2 flat values in NestedStructure but got 3."),
            (sparse_struct, dense_flat,
             "SparseTensorStructure corresponds to a single tf.variant "
             "vector of length 3."),
            (sparse_struct, pair_flat,
             "Expected 2 flat values in NestedStructure but got 3."),
            (pair_struct, dense_flat,
             "Expected 3 flat values in NestedStructure but got 2."),
            (pair_struct, sparse_flat,
             "Expected 3 flat values in NestedStructure but got 2."),
        )
        for target, bad_flat, pattern in restructure_failures:
            with self.assertRaisesRegexp(ValueError, pattern):
                target._from_tensor_list(bad_flat)

    @parameterized.named_parameters(
        ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", dtypes.int32, tensor_shape.matrix(
            2, 2), sparse_tensor.SparseTensor,
         structure.SparseTensorStructure(dtypes.int32, [2, 2])),
        ("TensorArray0", dtypes.int32, tensor_shape.as_shape(
            [None, True, 2, 2]), tensor_array_ops.TensorArray,
         structure.TensorArrayStructure(
             dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)),
        ("TensorArray1", dtypes.int32, tensor_shape.as_shape(
            [True, None, 2, 2]), tensor_array_ops.TensorArray,
         structure.TensorArrayStructure(
             dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)),
        ("TensorArray2", dtypes.int32,
         tensor_shape.as_shape([True, False, 2, 2
                                ]), tensor_array_ops.TensorArray,
         structure.TensorArrayStructure(
             dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)),
        ("Nest", {
            "a": dtypes.float32,
            "b": (dtypes.int32, dtypes.string)
        }, {
            "a": tensor_shape.scalar(),
            "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())
        }, {
            "a": ops.Tensor,
            "b": (sparse_tensor.SparseTensor, ops.Tensor)
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                   structure.TensorStructure(dtypes.string, []))
         })),
    )
    def testConvertLegacyStructure(self, output_types, output_shapes,
                                   output_classes, expected_structure):
        """Checks the structure built from legacy types/shapes/classes.

        Args:
          output_types: Types passed to `convert_legacy_structure`.
          output_shapes: Shapes passed to `convert_legacy_structure`.
          output_classes: Classes passed to `convert_legacy_structure`.
          expected_structure: Structure that must be mutually compatible
            with the converted one.
        """
        converted = structure.convert_legacy_structure(
            output_types, output_shapes, output_classes)
        # Compatibility has to hold in both directions.
        for lhs, rhs in ((expected_structure, converted),
                         (converted, expected_structure)):
            self.assertTrue(lhs.is_compatible_with(rhs))

    def testNestedNestedStructure(self):
        """Checks flattening and restoring through nested `NestedStructure`s.

        The flattened list must contain the very same tensor objects, and
        both `_from_tensor_list` and `_from_compatible_tensor_list` must
        hand those same objects back in the original nesting.
        """
        # Although `Structure.from_value()` will not construct one, a nested
        # structure containing nested `NestedStructure` objects can occur if
        # a structure is constructed manually.
        outer_leaf = structure.TensorStructure(dtypes.int64, [])
        inner = structure.NestedStructure(
            (structure.TensorStructure(dtypes.float32, []),
             structure.TensorStructure(dtypes.string, [])))
        s = structure.NestedStructure((outer_leaf, inner))

        int64_t = constant_op.constant(37, dtype=dtypes.int64)
        float32_t = constant_op.constant(42.0)
        string_t = constant_op.constant("Foo")
        leaves = [int64_t, float32_t, string_t]

        flat = s._to_tensor_list((int64_t, (float32_t, string_t)))
        for want, got in zip(leaves, flat):
            self.assertIs(want, got)

        for restore in (s._from_tensor_list, s._from_compatible_tensor_list):
            got_int64, (got_float32, got_string) = restore(flat)
            self.assertIs(int64_t, got_int64)
            self.assertIs(float32_t, got_float32)
            self.assertIs(string_t, got_string)

    @parameterized.named_parameters(
        ("Tensor", structure.TensorStructure(dtypes.float32, []), 32,
         structure.TensorStructure(dtypes.float32, [32])),
        ("TensorUnknown", structure.TensorStructure(dtypes.float32, []), None,
         structure.TensorStructure(dtypes.float32, [None])),
        ("SparseTensor", structure.SparseTensorStructure(
            dtypes.float32, [None]), 32,
         structure.SparseTensorStructure(dtypes.float32, [32, None])),
        ("SparseTensorUnknown",
         structure.SparseTensorStructure(dtypes.float32, [4]), None,
         structure.SparseTensorStructure(dtypes.float32, [None, 4])),
        ("Nest",
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                   structure.TensorStructure(dtypes.string, []))
         }), 128,
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, [128]),
             "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                   structure.TensorStructure(dtypes.string, [128]))
         })),
    )
    def testBatch(self, element_structure, batch_size,
                  expected_batched_structure):
        """Checks the structure returned by `_batch`.

        Args:
          element_structure: Structure on which `_batch` is called.
          batch_size: Argument passed to `_batch`.
          expected_batched_structure: Structure that must be mutually
            compatible with the result.
        """
        batched = element_structure._batch(batch_size)
        # Compatibility has to hold in both directions.
        for lhs, rhs in ((batched, expected_batched_structure),
                         (expected_batched_structure, batched)):
            self.assertTrue(lhs.is_compatible_with(rhs))

    @parameterized.named_parameters(
        ("Tensor", structure.TensorStructure(dtypes.float32, [32]),
         structure.TensorStructure(dtypes.float32, [])),
        ("TensorUnknown", structure.TensorStructure(dtypes.float32, [None]),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor",
         structure.SparseTensorStructure(dtypes.float32, [32, None]),
         structure.SparseTensorStructure(dtypes.float32, [None])),
        ("SparseTensorUnknown",
         structure.SparseTensorStructure(dtypes.float32, [None, 4]),
         structure.SparseTensorStructure(dtypes.float32, [4])),
        ("Nest",
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, [128]),
             "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                   structure.TensorStructure(dtypes.string, [None]))
         }),
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                   structure.TensorStructure(dtypes.string, []))
         })),
    )
    def testUnbatch(self, element_structure, expected_unbatched_structure):
        """Checks the structure returned by `_unbatch`.

        Args:
          element_structure: Structure on which `_unbatch` is called.
          expected_unbatched_structure: Structure that must be mutually
            compatible with the result.
        """
        unbatched = element_structure._unbatch()
        # Compatibility has to hold in both directions.
        for lhs, rhs in ((unbatched, expected_unbatched_structure),
                         (expected_unbatched_structure, unbatched)):
            self.assertTrue(lhs.is_compatible_with(rhs))

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
         lambda: constant_op.constant([1.0, 2.0])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[0]], values=[13], dense_shape=[2])),
        ("Nest", lambda:
         (constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
          sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])),
         lambda: (constant_op.constant([1.0, 2.0]),
                  sparse_tensor.SparseTensor(
                      indices=[[0]], values=[13], dense_shape=[2]))),
    )
    def testToBatchedTensorList(self, value_fn, element_0_fn):
        """Checks `_to_batched_tensor_list` and recovery of element 0.

        Args:
          value_fn: Zero-argument callable returning the batched value.
          element_0_fn: Zero-argument callable returning the value expected
            for the first element of the batch.
        """
        batched_value = value_fn()
        batched_struct = structure.Structure.from_value(batched_value)
        components = batched_struct._to_batched_tensor_list(batched_value)

        # The batch dimension is 2 for all of the test cases.
        # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
        # tensors in which we store sparse tensors, so those are skipped.
        for component in components:
            if component.dtype == dtypes.variant:
                continue
            leading_dim = array_ops.shape(component)[0]
            self.assertEqual(2, self.evaluate(leading_dim))

        # Row 0 of every batched component, restored with the unbatched
        # structure, must equal the expected first element.
        want_element = self.evaluate(element_0_fn())
        element_struct = batched_struct._unbatch()
        first_rows = [component[0] for component in components]
        got_element = element_struct._from_tensor_list(first_rows)

        flat_pairs = zip(nest.flatten(want_element),
                         nest.flatten(got_element))
        for want, got in flat_pairs:
            if sparse_tensor.is_sparse(want):
                check = self.assertSparseValuesEqual
            else:
                check = self.assertAllEqual
            check(want, got)
Beispiel #26
0
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
    def testFromValue(self):
        """An optional built from a scalar constant has that value."""
        wrapped = constant_op.constant(37.0)
        opt = optional_ops.Optional.from_value(wrapped)

        has_value = self.evaluate(opt.has_value())
        self.assertTrue(has_value)

        value = self.evaluate(opt.get_value())
        self.assertEqual(37.0, value)

    def testFromStructuredValue(self):
        """An optional built from a nested dict/tuple returns that nest."""
        nested_value = {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")),
        }
        opt = optional_ops.Optional.from_value(nested_value)
        self.assertTrue(self.evaluate(opt.has_value()))

        # String tensors evaluate to bytes.
        expected = {"a": 37.0, "b": ([b"Foo"], b"Bar")}
        self.assertEqual(expected, self.evaluate(opt.get_value()))

    def testFromSparseTensor(self):
        """Sparse tensor values stored in an optional come back unchanged."""
        rank1 = sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0], dtype=np.int64),
            dense_shape=np.array([1]))
        rank2 = sparse_tensor.SparseTensorValue(
            indices=np.array([[0, 0], [1, 1]]),
            values=np.array([-1., 1.], dtype=np.float32),
            dense_shape=np.array([2, 2]))

        opt = optional_ops.Optional.from_value((rank1, rank2))
        self.assertTrue(self.evaluate(opt.has_value()))

        got_rank1, got_rank2 = opt.get_value()
        # Compare the three components of each sparse tensor in turn.
        for want, got in ((rank1, got_rank1), (rank2, got_rank2)):
            for field in ("indices", "values", "dense_shape"):
                self.assertAllEqual(getattr(want, field),
                                    self.evaluate(getattr(got, field)))

    def testFromNone(self):
        """An empty optional keeps its structure but holds no value."""
        scalar_float = structure.TensorStructure(dtypes.float32, [])
        opt = optional_ops.Optional.none_from_structure(scalar_float)

        self.assertTrue(opt.value_structure.is_compatible_with(scalar_float))

        # A different shape or a different dtype must not be compatible.
        mismatched_structures = (
            structure.TensorStructure(dtypes.float32, [1]),
            structure.TensorStructure(dtypes.int32, []),
        )
        for other in mismatched_structures:
            self.assertFalse(opt.value_structure.is_compatible_with(other))

        self.assertFalse(self.evaluate(opt.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(opt.get_value())

    def testAddN(self):
        """`add_n` on optional variants sums values and keeps empties empty."""
        if test_util.is_gpu_available():
            devices = ["/cpu:0", "/gpu:0"]
        else:
            devices = ["/cpu:0"]

        for device in devices:
            with ops.device(device):
                # Both operands hold a value.
                lhs = optional_ops.Optional.from_value((1.0, 2.0))
                rhs = optional_ops.Optional.from_value((3.0, 4.0))

                summed_variant = math_ops.add_n(
                    [lhs._variant_tensor, rhs._variant_tensor])
                summed = optional_ops._OptionalImpl(summed_variant,
                                                    lhs.value_structure)
                self.assertAllEqual(self.evaluate(summed.get_value()),
                                    (4.0, 6.0))

                # Neither operand holds a value.
                lhs_none = optional_ops.Optional.none_from_structure(
                    lhs.value_structure)
                rhs_none = optional_ops.Optional.none_from_structure(
                    rhs.value_structure)
                summed_variant = math_ops.add_n(
                    [lhs_none._variant_tensor, rhs_none._variant_tensor])
                summed = optional_ops._OptionalImpl(summed_variant,
                                                    lhs_none.value_structure)
                self.assertFalse(self.evaluate(summed.has_value()))

    def testNestedAddN(self):
        """`add_n` also sums optionals nested inside other optionals.

        Two outer optionals each hold a float and the variant tensor of an
        inner optional. After `add_n` on the outer variants, both the outer
        float and the inner optional's components are checked.
        """
        devices = ["/cpu:0"]
        if test_util.is_gpu_available():
            devices.append("/gpu:0")
        for device in devices:
            with ops.device(device):
                opt1 = optional_ops.Optional.from_value([1, 2.0])
                opt2 = optional_ops.Optional.from_value([3, 4.0])
                opt3 = optional_ops.Optional.from_value(
                    (5.0, opt1._variant_tensor))
                opt4 = optional_ops.Optional.from_value(
                    (6.0, opt2._variant_tensor))

                add_tensor = math_ops.add_n(
                    [opt3._variant_tensor, opt4._variant_tensor])
                add_opt = optional_ops._OptionalImpl(add_tensor,
                                                     opt3.value_structure)
                self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0)

                # The second outer component is itself an optional variant;
                # re-wrap it with the inner structure to read its value.
                inner_add_opt = optional_ops._OptionalImpl(
                    add_opt.get_value()[1], opt1.value_structure)
                # Evaluate before comparing, as the other assertions in this
                # test do, so the comparison is made on concrete values rather
                # than on tensor objects.
                self.assertAllEqual(
                    self.evaluate(inner_add_opt.get_value()), [4, 6.0])

    def testZerosLike(self):
        """`zeros_like` on an optional variant zeroes values, keeps empties."""
        if test_util.is_gpu_available():
            devices = ["/cpu:0", "/gpu:0"]
        else:
            devices = ["/cpu:0"]

        for device in devices:
            with ops.device(device):
                # Optional that holds a value.
                full = optional_ops.Optional.from_value((1.0, 2.0))
                zeroed_variant = array_ops.zeros_like(full._variant_tensor)
                zeroed = optional_ops._OptionalImpl(zeroed_variant,
                                                    full.value_structure)
                self.assertAllEqual(self.evaluate(zeroed.get_value()),
                                    (0.0, 0.0))

                # Optional that holds no value.
                empty = optional_ops.Optional.none_from_structure(
                    full.value_structure)
                zeroed_variant = array_ops.zeros_like(empty._variant_tensor)
                zeroed = optional_ops._OptionalImpl(zeroed_variant,
                                                    empty.value_structure)
                self.assertFalse(self.evaluate(zeroed.has_value()))

    def testNestedZerosLike(self):
        """`zeros_like` also zeroes an optional nested inside an optional."""
        if test_util.is_gpu_available():
            devices = ["/cpu:0", "/gpu:0"]
        else:
            devices = ["/cpu:0"]

        for device in devices:
            with ops.device(device):
                inner = optional_ops.Optional.from_value(1.0)
                outer = optional_ops.Optional.from_value(inner._variant_tensor)

                zeroed_outer = optional_ops._OptionalImpl(
                    array_ops.zeros_like(outer._variant_tensor),
                    outer.value_structure)
                # The outer value is the inner optional's variant tensor;
                # re-wrap it with the inner structure to read the float.
                zeroed_inner = optional_ops._OptionalImpl(
                    zeroed_outer.get_value(), inner.value_structure)
                self.assertEqual(self.evaluate(zeroed_inner.get_value()),
                                 0.0)

    def testCopyToGPU(self):
        """Optionals created on CPU remain readable after identity on GPU."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        with ops.device("/cpu:0"):
            components = (constant_op.constant(37.0),
                          constant_op.constant("Foo"),
                          constant_op.constant(42))
            cpu_full = optional_ops.Optional.from_value(components)
            cpu_empty = optional_ops.Optional.none_from_structure(
                structure.TensorStructure(dtypes.float32, []))

        with ops.device("/gpu:0"):
            gpu_full = optional_ops._OptionalImpl(
                array_ops.identity(cpu_full._variant_tensor),
                cpu_full.value_structure)
            gpu_empty = optional_ops._OptionalImpl(
                array_ops.identity(cpu_empty._variant_tensor),
                cpu_empty.value_structure)

            # Build the accessor ops while still under the GPU device scope.
            full_has_value = gpu_full.has_value()
            full_values = gpu_full.get_value()
            empty_has_value = gpu_empty.has_value()

        self.assertTrue(self.evaluate(full_has_value))
        self.assertEqual((37.0, b"Foo", 42), self.evaluate(full_values))
        self.assertFalse(self.evaluate(empty_has_value))

    def testNestedCopyToGPU(self):
        """An optional holding other optionals stays readable on GPU."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        with ops.device("/cpu:0"):
            components = (constant_op.constant(37.0),
                          constant_op.constant("Foo"),
                          constant_op.constant(42))
            cpu_full = optional_ops.Optional.from_value(components)
            cpu_empty = optional_ops.Optional.none_from_structure(
                structure.TensorStructure(dtypes.float32, []))
            # Outer optional: (full optional, empty optional, plain float).
            cpu_nested = optional_ops.Optional.from_value(
                (cpu_full._variant_tensor, cpu_empty._variant_tensor, 1.0))

        with ops.device("/gpu:0"):
            gpu_nested = optional_ops._OptionalImpl(
                array_ops.identity(cpu_nested._variant_tensor),
                cpu_nested.value_structure)

            nested_has_value = gpu_nested.has_value()
            nested_values = gpu_nested.get_value()

        self.assertTrue(self.evaluate(nested_has_value))

        # Re-wrap the two inner variants with their original structures.
        inner_full = optional_ops._OptionalImpl(nested_values[0],
                                                cpu_full.value_structure)
        inner_empty = optional_ops._OptionalImpl(nested_values[1],
                                                 cpu_empty.value_structure)

        self.assertEqual((37.0, b"Foo", 42),
                         self.evaluate(inner_full.get_value()))
        self.assertFalse(self.evaluate(inner_empty.has_value()))
        self.assertEqual(1.0, self.evaluate(nested_values[2]))

    def _assertElementValueEqual(self, expected, actual):
        """Asserts that two evaluated element values are equal.

        Dicts are compared key by key (recursively), `SparseTensorValue`s
        component by component, and anything else with `assertAllEqual`.
        """
        if isinstance(expected, dict):
            self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
            for key in expected.keys():
                self._assertElementValueEqual(expected[key], actual[key])
            return

        if isinstance(expected, sparse_tensor.SparseTensorValue):
            for field in ("indices", "values", "dense_shape"):
                self.assertAllEqual(getattr(expected, field),
                                    getattr(actual, field))
            return

        self.assertAllEqual(expected, actual)

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0, 1]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[10, 10]),
         structure.SparseTensorStructure(dtypes.int32, [10, 10])),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        }, {
            "a":
            structure.TensorStructure(dtypes.float32, []),
            "b": (structure.TensorStructure(dtypes.string, [1]),
                  structure.TensorStructure(dtypes.string, []))
        }),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalStructure(
             structure.TensorStructure(dtypes.float32, []))),
    )
    def testOptionalStructure(self, tf_value_fn, expected_value_structure):
        """Checks the structure reported for an optional and its round trip.

        Args:
          tf_value_fn: Callable returning the value to wrap in an optional.
          expected_value_structure: Structure the optional's value structure
            must be compatible with.
        """
        tf_value = tf_value_fn()
        opt = optional_ops.Optional.from_value(tf_value)

        value_structure_ok = structure.are_compatible(
            opt.value_structure, expected_value_structure)
        self.assertTrue(value_structure_ok)

        opt_structure = structure.type_spec_from_value(opt)
        self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
        self.assertTrue(structure.are_compatible(opt_structure, opt_structure))
        inner_structure_ok = structure.are_compatible(
            opt_structure._value_structure, expected_value_structure)
        self.assertTrue(inner_structure_ok)
        # The optional flattens to one scalar variant tensor.
        self.assertEqual([dtypes.variant], opt_structure._flat_types)
        self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes)

        # An OptionalStructure must not be compatible with the structure of
        # a non-optional value.
        plain_structure = structure.type_spec_from_value(
            constant_op.constant(42.0))
        self.assertFalse(opt_structure.is_compatible_with(plain_structure))

        # The optional must survive a round trip through _to_tensor_list()
        # and _from_tensor_list().
        flat_components = opt_structure._to_tensor_list(opt)
        round_trip_opt = opt_structure._from_tensor_list(flat_components)

        # When the wrapped value is itself an optional, unwrap one more level
        # on both sides before comparing.
        wraps_optional = isinstance(tf_value, optional_ops.Optional)
        want = self.evaluate(
            tf_value.get_value() if wraps_optional else tf_value)
        restored = round_trip_opt.get_value()
        if wraps_optional:
            restored = restored.get_value()
        got = self.evaluate(restored)
        self._assertElementValueEqual(want, got)

    @parameterized.named_parameters(
        ("Tensor", np.array([1, 2, 3], dtype=np.int32),
         lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
        ("SparseTensor",
         sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                         values=np.array([-1., 1.],
                                                         dtype=np.float32),
                                         dense_shape=[2, 2]),
         lambda: sparse_tensor.SparseTensor(indices=[[0, 1], [1, 0]],
                                            values=[37.0, 42.0],
                                            dense_shape=[2, 2]), False),
        ("Nest", {
            "a":
            np.array([1, 2, 3], dtype=np.int32),
            "b":
            sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                            values=np.array([-1., 1.],
                                                            dtype=np.float32),
                                            dense_shape=[2, 2])
        }, lambda: {
            "a":
            constant_op.constant([4, 5, 6], dtype=dtypes.int32),
            "b":
            sparse_tensor.SparseTensor(indices=[[0, 1], [1, 0]],
                                       values=[37.0, 42.0],
                                       dense_shape=[2, 2])
        }, False),
    )
    def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                      works_on_gpu):
        """Checks `get_next_as_optional` before, during and after iteration.

        The dataset holds `np_value` repeated three times. Each of the first
        three optionals must hold `np_value`; later ones must be empty and
        raise `InvalidArgumentError` when their value is requested.

        Args:
          np_value: Element value used to build the dataset and the value
            expected from each non-empty optional.
          tf_value_fn: Callable returning a TensorFlow value whose type spec
            the optional's value structure must be compatible with.
          works_on_gpu: If False, the test is skipped when a GPU is available.
        """
        if not works_on_gpu and test.is_gpu_available():
            self.skipTest("Test case not yet supported on GPU.")
        ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)

        if context.executing_eagerly():
            iterator = dataset_ops.make_one_shot_iterator(ds)
            # For each element of the dataset, assert that the optional
            # evaluates to the expected value.
            for _ in range(3):
                next_elem = iterator_ops.get_next_as_optional(iterator)
                self.assertIsInstance(next_elem, optional_ops.Optional)
                self.assertTrue(
                    structure.are_compatible(
                        next_elem.value_structure,
                        structure.type_spec_from_value(tf_value_fn())))
                self.assertTrue(next_elem.has_value())
                self._assertElementValueEqual(np_value, next_elem.get_value())
            # After exhausting the iterator, `next_elem.has_value()` evaluates
            # to false, and attempting to get the value fails. This is checked
            # twice.
            for _ in range(2):
                next_elem = iterator_ops.get_next_as_optional(iterator)
                self.assertFalse(self.evaluate(next_elem.has_value()))
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(next_elem.get_value())
        else:
            iterator = dataset_ops.make_initializable_iterator(ds)
            next_elem = iterator_ops.get_next_as_optional(iterator)
            self.assertIsInstance(next_elem, optional_ops.Optional)
            self.assertTrue(
                structure.are_compatible(
                    next_elem.value_structure,
                    structure.type_spec_from_value(tf_value_fn())))
            # Before initializing the iterator, evaluating the optional fails
            # with a FailedPreconditionError. This is only relevant in graph
            # mode.
            elem_has_value_t = next_elem.has_value()
            elem_value_t = next_elem.get_value()
            with self.assertRaises(errors.FailedPreconditionError):
                self.evaluate(elem_has_value_t)
            with self.assertRaises(errors.FailedPreconditionError):
                self.evaluate(elem_value_t)
            # Now we initialize the iterator.
            self.evaluate(iterator.initializer)
            # For each element of the dataset, assert that the optional
            # evaluates to the expected value. Both tensors are fetched in a
            # single evaluate() call.
            for _ in range(3):
                elem_has_value, elem_value = self.evaluate(
                    [elem_has_value_t, elem_value_t])
                self.assertTrue(elem_has_value)
                self._assertElementValueEqual(np_value, elem_value)

            # After exhausting the iterator, `next_elem.has_value()` evaluates
            # to false, and attempting to get the value fails. This is checked
            # twice.
            for _ in range(2):
                self.assertFalse(self.evaluate(elem_has_value_t))
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(elem_value_t)

    def testFunctionBoundaries(self):
        """An optional's variant tensor can pass between two functions."""

        @def_function.function
        def get_optional():
            # TODO(skyewm): support returning Optionals from functions?
            return optional_ops.Optional.from_value(
                constant_op.constant(1.0))._variant_tensor

        # TODO(skyewm): support Optional arguments?
        @def_function.function
        def consume_optional(opt_tensor):
            scalar_float = structure.TensorStructure(dtypes.float32, [])
            return optional_ops._OptionalImpl(opt_tensor,
                                              scalar_float).get_value()

        result = consume_optional(get_optional())
        self.assertEqual(self.evaluate(result), 1.0)

    def testLimitedRetracing(self):
        """Calling a function with different optionals traces it only once."""
        # One-element list so the nested function can mutate the counter.
        traces = [0]

        @def_function.function
        def f(opt):
            traces[0] += 1
            return opt.get_value()

        cases = (
            (optional_ops.Optional.from_value(constant_op.constant(37.0)),
             37.0),
            (optional_ops.Optional.from_value(constant_op.constant(42.0)),
             42.0),
        )

        for _ in range(10):
            for opt, expected in cases:
                self.assertEqual(self.evaluate(f(opt)), expected)
            self.assertEqual(traces[0], 1)
Beispiel #27
0
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
                    test_util.TensorFlowTestCase):

    # pylint: disable=g-long-lambda,protected-access
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0), tensor_spec.TensorSpec,
         [dtypes.float32], [[]]),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         tensor_array_ops.TensorArraySpec, [dtypes.variant], [[]]),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         sparse_tensor.SparseTensorSpec, [dtypes.variant], [None]),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
         ragged_tensor.RaggedTensorSpec, [dtypes.variant], [None]),
        ("Nested_0", lambda:
         (constant_op.constant(37.0), constant_op.constant([1, 2, 3])), tuple,
         [dtypes.float32, dtypes.int32], [[], [3]]),
        ("Nested_1", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, dict, [dtypes.float32, dtypes.int32], [[], [3]]),
        ("Nested_2", lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, dict, [dtypes.float32, dtypes.variant, dtypes.variant], [[], None,
                                                                    None]),
    )
    def testFlatStructure(self, value_fn, expected_structure, expected_types,
                          expected_shapes):
        """Checks the spec type and flat dtypes/shapes derived from a value.

        Args:
          value_fn: Callable returning the value to derive a type spec from.
          expected_structure: Type the derived spec must be an instance of.
          expected_types: Expected list of flat tensor dtypes.
          expected_shapes: Expected flat shapes as lists; `None` marks a
            shape whose rank must be unknown.
        """
        spec = structure.type_spec_from_value(value_fn())
        self.assertIsInstance(spec, expected_structure)

        self.assertEqual(expected_types,
                         structure.get_flat_tensor_types(spec))

        flat_shapes = structure.get_flat_tensor_shapes(spec)
        self.assertLen(flat_shapes, len(expected_shapes))
        for want, got in zip(expected_shapes, flat_shapes):
            if want is not None:
                self.assertEqual(got.as_list(), want)
                continue
            self.assertEqual(got.ndims, None)

    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0), lambda: [
            constant_op.constant(38.0),
            array_ops.placeholder(dtypes.float32),
            variables.Variable(100.0), 42.0,
            np.array(42.0, dtype=np.float32)
        ],
         lambda: [constant_op.constant([1.0, 2.0]),
                  constant_op.constant(37)]),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0), lambda: [
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(3, ), size=0),
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(3, ), size=10)
            ], lambda: [
                tensor_array_ops.TensorArray(
                    dtype=dtypes.int32, element_shape=(3, ), size=0),
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(), size=0)
            ]),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [
                sparse_tensor.SparseTensor(indices=[[1, 1], [3, 4]],
                                           values=[10, -1],
                                           dense_shape=[4, 5]),
                sparse_tensor.SparseTensorValue(indices=[[1, 1], [3, 4]],
                                                values=[10, -1],
                                                dense_shape=[4, 5]),
                array_ops.sparse_placeholder(dtype=dtypes.int32),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None])
            ], lambda: [
                constant_op.constant(37, shape=[4, 5]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None, None]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
            ]),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[1, 2], [], [3]]), lambda: [
             ragged_factory_ops.constant([[1, 2], [3, 4], []]),
             ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
         ], lambda: [
             ragged_factory_ops.constant(1),
             ragged_factory_ops.constant([1, 2]),
             ragged_factory_ops.constant([[1], [2]]),
             ragged_factory_ops.constant([["a", "b"]]),
         ]),
        ("Nested", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6])
        }], lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6, 7])
        }, {
            "a": constant_op.constant(15),
            "b": constant_op.constant([4, 5, 6])
        }, {
            "a":
            constant_op.constant(15),
            "b":
            sparse_tensor.SparseTensor(
                indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
        }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
    )
    @test_util.run_deprecated_v1
    def testIsCompatibleWithStructure(self, original_value_fn,
                                      compatible_values_fn,
                                      incompatible_values_fn):
        """Checks `are_compatible` against known (in)compatible values.

        Args:
          original_value_fn: Callable returning the reference value.
          compatible_values_fn: Callable returning values whose type specs
            must be compatible with the reference value's spec.
          incompatible_values_fn: Callable returning values whose type specs
            must not be compatible with the reference value's spec.
        """
        # Build every value before deriving any spec.
        original_value = original_value_fn()
        compatible_values = compatible_values_fn()
        incompatible_values = incompatible_values_fn()
        reference_spec = structure.type_spec_from_value(original_value)

        def matches_reference(candidate):
            candidate_spec = structure.type_spec_from_value(candidate)
            return structure.are_compatible(reference_spec, candidate_spec)

        for candidate in compatible_values:
            self.assertTrue(matches_reference(candidate))
        for candidate in incompatible_values:
            self.assertFalse(matches_reference(candidate))

    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         lambda: constant_op.constant(42.0),
         lambda: constant_op.constant([5])),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.int32, element_shape=(), size=0)),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[3]], values=[-1], dense_shape=[5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
         lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]), lambda:
         ragged_factory_ops.constant([[[1]], [[2], [3]]], ragged_rank=1),
         lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
         lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])),
        ("Nested", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: {
            "a": constant_op.constant(42.0),
            "b": constant_op.constant([4, 5, 6])
        }, lambda: {
            "a": constant_op.constant([1, 2, 3]),
            "b": constant_op.constant(37.0)
        }),
    )  # pyformat: disable
    def testStructureFromValueEquality(self, value1_fn, value2_fn,
                                       *not_equal_value_fns):
        """Checks `==`, `!=` and `hash` on type specs derived from values.

        Args:
          value1_fn: Callable returning the first value.
          value2_fn: Callable returning a value whose spec must equal the
            first value's spec.
          *not_equal_value_fns: Callables returning values whose specs must
            differ from both of the above.
        """
        # The assertions below intentionally spell out `==` and `!=` so that
        # both operators are exercised directly.
        # pylint: disable=g-generic-assert
        s1 = structure.type_spec_from_value(value1_fn())
        s2 = structure.type_spec_from_value(value2_fn())
        self.assertEqual(s1, s1)  # check __eq__ operator.
        self.assertEqual(s1, s2)  # check __eq__ operator.
        self.assertFalse(s1 != s1)  # check __ne__ operator.
        self.assertFalse(s1 != s2)  # check __ne__ operator.
        # Equal specs must hash equally, component by component.
        for c1, c2 in zip(nest.flatten(s1), nest.flatten(s2)):
            self.assertEqual(hash(c1), hash(c1))
            self.assertEqual(hash(c1), hash(c2))
        for value_fn in not_equal_value_fns:
            s3 = structure.type_spec_from_value(value_fn())
            self.assertNotEqual(s1, s3)  # check __ne__ operator.
            self.assertNotEqual(s2, s3)  # check __ne__ operator.
            self.assertFalse(s1 == s3)  # check __eq__ operator.
            self.assertFalse(s2 == s3)  # check __eq__ operator.

    @parameterized.named_parameters(
        ("RaggedTensor_RaggedRank",
         structure.RaggedTensorStructure(dtypes.int32, None, 1),
         structure.RaggedTensorStructure(dtypes.int32, None, 2)),
        ("RaggedTensor_Shape",
         structure.RaggedTensorStructure(dtypes.int32, [3, None], 1),
         structure.RaggedTensorStructure(dtypes.int32, [5, None], 1)),
        ("RaggedTensor_DType",
         structure.RaggedTensorStructure(dtypes.int32, None, 1),
         structure.RaggedTensorStructure(dtypes.float32, None, 1)),
    )
    def testRaggedStructureInequality(self, s1, s2):
        """Checks that two differing ragged structures compare unequal.

        Args:
          s1: First ragged structure.
          s2: Ragged structure that differs from `s1` in one constructor
            argument.
        """
        # `==` is spelled out on purpose so the operator is exercised directly.
        # pylint: disable=g-generic-assert
        self.assertNotEqual(s1, s2)  # check __ne__ operator.
        self.assertFalse(s1 == s2)  # check __eq__ operator.

    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         lambda: constant_op.constant(42.0),
         lambda: constant_op.constant([5])),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.int32, element_shape=(), size=0)),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[3]], values=[-1], dense_shape=[5])),
        ("Nested", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: {
            "a": constant_op.constant(42.0),
            "b": constant_op.constant([4, 5, 6])
        }, lambda: {
            "a": constant_op.constant([1, 2, 3]),
            "b": constant_op.constant(37.0)
        }),
    )
    def testHash(self, value1_fn, value2_fn, value3_fn):
        """Checks component hashes of type specs derived from three values.

        Args:
          value1_fn: Callable returning the first value.
          value2_fn: Callable returning a value whose spec components must
            hash like the first value's.
          value3_fn: Callable returning a value whose spec components must
            hash differently from the other two.
        """
        flat_specs = []
        for value_fn in (value1_fn, value2_fn, value3_fn):
            spec = structure.type_spec_from_value(value_fn())
            flat_specs.append(spec)
        flat_specs = [nest.flatten(spec) for spec in flat_specs]

        for same_a, same_b, different in zip(*flat_specs):
            # Hashing the same component twice must give the same result.
            self.assertEqual(hash(same_a), hash(same_a))
            self.assertEqual(hash(same_a), hash(same_b))
            self.assertNotEqual(hash(same_a), hash(different))
            self.assertNotEqual(hash(same_b), hash(different))

    @parameterized.named_parameters(
        (
            "Tensor",
            lambda: constant_op.constant(37.0),
        ),
        (
            "SparseTensor",
            lambda: sparse_tensor.SparseTensor(
                indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
        ),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
        (
            "RaggedTensor",
            lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
        ),
        (
            "Nested_0",
            lambda: {
                "a": constant_op.constant(37.0),
                "b": constant_op.constant([1, 2, 3])
            },
        ),
        (
            "Nested_1",
            lambda: {
                "a":
                constant_op.constant(37.0),
                "b": (sparse_tensor.SparseTensor(
                    indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                      sparse_tensor.SparseTensor(
                          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
            },
        ),
    )
    def testRoundTripConversion(self, value_fn):
        """Checks that a value survives to_tensor_list()/from_tensor_list().

        Args:
          value_fn: Callable returning the value to convert to a flat tensor
            list and back.
        """
        value = value_fn()
        s = structure.type_spec_from_value(value)

        def maybe_stack_ta(v):
            # Stack a TensorArray into a tensor before it is evaluated; any
            # other value is passed through unchanged.
            if isinstance(v, tensor_array_ops.TensorArray):
                return v.stack()
            return v

        before = self.evaluate(maybe_stack_ta(value))
        after = self.evaluate(
            maybe_stack_ta(
                structure.from_tensor_list(
                    s, structure.to_tensor_list(s, value))))

        flat_before = nest.flatten(before)
        flat_after = nest.flatten(after)
        # zip() stops at the shorter input, so without this check a round
        # trip that dropped components would still pass.
        self.assertLen(flat_after, len(flat_before))
        for b, a in zip(flat_before, flat_after):
            if isinstance(b, sparse_tensor.SparseTensorValue):
                self.assertAllEqual(b.indices, a.indices)
                self.assertAllEqual(b.values, a.values)
                self.assertAllEqual(b.dense_shape, a.dense_shape)
            else:
                # Dense and ragged values are both compared with
                # assertAllEqual.
                self.assertAllEqual(b, a)

    # pylint: enable=g-long-lambda

    def preserveStaticShape(self):
        """Checks static shapes of ragged/sparse components after a round trip.

        NOTE(review): the method name lacks the `test` prefix, so default
        unittest collection will not run it. The name is kept as-is here;
        confirm whether it is meant to be an active test.
        """

        def round_trip(value):
            # Flatten `value` with its own spec, then rebuild it.
            spec = structure.type_spec_from_value(value)
            return structure.from_tensor_list(
                spec, structure.to_tensor_list(spec, value))

        ragged = ragged_factory_ops.constant([[1, 2], [], [3]])
        ragged_after = round_trip(ragged)
        self.assertEqual(ragged_after.row_splits.shape.as_list(),
                         ragged.row_splits.shape.as_list())
        self.assertEqual(ragged_after.values.shape.as_list(), [None])

        sparse = sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
        sparse_after = round_trip(sparse)
        self.assertEqual(sparse_after.indices.shape.as_list(), [None, 2])
        self.assertEqual(sparse_after.values.shape.as_list(), [None])
        self.assertEqual(sparse_after.dense_shape.shape.as_list(),
                         sparse.dense_shape.shape.as_list())

    def testIncompatibleStructure(self):
        """Checks that mismatched structures are rejected in both directions.

        Three mutually incompatible values (a dense tensor, a sparse tensor and
        a dict of tensors) are built together with their structures. Then:
          1. Flattening a value with the structure of a different value fails.
          2. Restructuring a flattened value with the structure of a different
             value fails.
        """
        dense_value = constant_op.constant(42.0)
        dense_spec = structure.type_spec_from_value(dense_value)
        dense_flat = structure.to_tensor_list(dense_spec, dense_value)

        sparse_value = sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])
        sparse_spec = structure.type_spec_from_value(sparse_value)
        sparse_flat = structure.to_tensor_list(sparse_spec, sparse_value)

        nest_value = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3]),
        }
        nest_spec = structure.type_spec_from_value(nest_value)
        nest_flat = structure.to_tensor_list(nest_spec, nest_value)

        nest_mismatch = (
            "The two structures don't have the same nested structure.")

        # (structure, value, expected exception type, expected message regexp)
        flatten_failures = (
            (dense_spec, sparse_value, ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*float32.* and shape \(\)"),
            (dense_spec, nest_value, ValueError, nest_mismatch),
            (sparse_spec, dense_value, TypeError,
             "Neither a SparseTensor nor SparseTensorValue"),
            (sparse_spec, nest_value, ValueError, nest_mismatch),
            (nest_spec, dense_value, ValueError, nest_mismatch),
            (nest_spec, sparse_value, ValueError, nest_mismatch),
        )
        for spec, value, error_type, pattern in flatten_failures:
            with self.assertRaisesRegexp(error_type, pattern):
                structure.to_tensor_list(spec, value)

        # (structure, flat tensor list, expected ValueError message regexp)
        restructure_failures = (
            (dense_spec, sparse_flat, r"Incompatible input:"),
            (dense_spec, nest_flat, "Expected 1 tensors but got 2."),
            (sparse_spec, dense_flat, "Incompatible input: "),
            (sparse_spec, nest_flat, "Expected 1 tensors but got 2."),
            (nest_spec, dense_flat, "Expected 2 tensors but got 1."),
            (nest_spec, sparse_flat, "Expected 2 tensors but got 1."),
        )
        for spec, tensor_list, pattern in restructure_failures:
            with self.assertRaisesRegexp(ValueError, pattern):
                structure.from_tensor_list(spec, tensor_list)

    def testIncompatibleNestedStructure(self):
        """Checks that mismatched nested structures are rejected.

        Three mutually incompatible nested values are built together with
        their structures. Then:
          1. Flattening a value with the structure of a different value fails.
          2. Restructuring a flattened value with the structure of a different
             value fails.
        """
        dense_pair = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3]),
        }
        dense_pair_spec = structure.type_spec_from_value(dense_pair)
        dense_pair_flat = structure.to_tensor_list(dense_pair_spec, dense_pair)

        # Same nesting as `dense_pair`, but "b" is a different class.
        mixed_pair = {
            "a": constant_op.constant(37.0),
            "b": sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
        }
        mixed_pair_spec = structure.type_spec_from_value(mixed_pair)
        mixed_pair_flat = structure.to_tensor_list(mixed_pair_spec, mixed_pair)

        # Nesting that is incompatible with both values above: "b" is a tuple.
        deeper = {
            "a": constant_op.constant(37.0),
            "b": (
                sparse_tensor.SparseTensor(
                    indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
            ),
        }
        deeper_spec = structure.type_spec_from_value(deeper)
        deeper_flat = structure.to_tensor_list(deeper_spec, deeper)

        # NOTE(mrry): The repr of the dictionaries is not sorted, so this
        # pattern only matches the fixed leading part of the message. It might
        # be worth adding a deterministic repr for these error messages (among
        # other improvements).
        nest_mismatch = (
            "The two structures don't have the same nested structure.")

        # (structure, value, expected exception type, expected message regexp)
        flatten_failures = (
            (dense_pair_spec, mixed_pair, ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*int32.* and shape \(3,\)"),
            (dense_pair_spec, deeper, ValueError, nest_mismatch),
            (mixed_pair_spec, dense_pair, TypeError,
             "Neither a SparseTensor nor SparseTensorValue"),
            (mixed_pair_spec, deeper, ValueError, nest_mismatch),
            (deeper_spec, dense_pair, ValueError, nest_mismatch),
            (deeper_spec, mixed_pair, ValueError, nest_mismatch),
        )
        for spec, value, error_type, pattern in flatten_failures:
            with self.assertRaisesRegexp(error_type, pattern):
                structure.to_tensor_list(spec, value)

        # (structure, flat tensor list, expected ValueError message regexp)
        restructure_failures = (
            (dense_pair_spec, mixed_pair_flat, r"Incompatible input:"),
            (dense_pair_spec, deeper_flat, "Expected 2 tensors but got 3."),
            (mixed_pair_spec, dense_pair_flat, "Incompatible input: "),
            (mixed_pair_spec, deeper_flat, "Expected 2 tensors but got 3."),
            (deeper_spec, dense_pair_flat, "Expected 3 tensors but got 2."),
            (deeper_spec, mixed_pair_flat, "Expected 3 tensors but got 2."),
        )
        for spec, tensor_list, pattern in restructure_failures:
            with self.assertRaisesRegexp(ValueError, pattern):
                structure.from_tensor_list(spec, tensor_list)

    @parameterized.named_parameters(
        ("Tensor",
         dtypes.float32,
         tensor_shape.scalar(),
         ops.Tensor,
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor",
         dtypes.int32,
         tensor_shape.matrix(2, 2),
         sparse_tensor.SparseTensor,
         structure.SparseTensorStructure(dtypes.int32, [2, 2])),
        ("TensorArray_0",
         dtypes.int32,
         tensor_shape.as_shape([None, True, 2, 2]),
         tensor_array_ops.TensorArray,
         structure.TensorArrayStructure(
             dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)),
        ("TensorArray_1",
         dtypes.int32,
         tensor_shape.as_shape([True, None, 2, 2]),
         tensor_array_ops.TensorArray,
         structure.TensorArrayStructure(
             dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)),
        ("TensorArray_2",
         dtypes.int32,
         tensor_shape.as_shape([True, False, 2, 2]),
         tensor_array_ops.TensorArray,
         structure.TensorArrayStructure(
             dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)),
        ("RaggedTensor",
         dtypes.int32,
         tensor_shape.matrix(2, None),
         structure.RaggedTensorStructure(dtypes.int32, [2, None], 1),
         structure.RaggedTensorStructure(dtypes.int32, [2, None], 1)),
        ("Nested",
         {
             "a": dtypes.float32,
             "b": (dtypes.int32, dtypes.string),
         },
         {
             "a": tensor_shape.scalar(),
             "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar()),
         },
         {
             "a": ops.Tensor,
             "b": (sparse_tensor.SparseTensor, ops.Tensor),
         },
         {
             "a": structure.TensorStructure(dtypes.float32, []),
             "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                   structure.TensorStructure(dtypes.string, [])),
         }),
    )
    def testConvertLegacyStructure(self, output_types, output_shapes,
                                   output_classes, expected_structure):
        """Checks `convert_legacy_structure` against an expected structure.

        Args:
          output_types: First argument passed to `convert_legacy_structure`.
          output_shapes: Second argument passed to `convert_legacy_structure`.
          output_classes: Third argument passed to `convert_legacy_structure`.
          expected_structure: Structure the conversion result must equal.
        """
        self.assertEqual(
            structure.convert_legacy_structure(
                output_types, output_shapes, output_classes),
            expected_structure)

    def testNestedNestedStructure(self):
        """Checks flatten/rebuild of a tuple nested inside a tuple.

        The flattened list and both rebuilt nests must hold the very same
        tensor objects that were passed in (identity, not just equality).
        """
        spec = (
            structure.TensorStructure(dtypes.int64, []),
            (
                structure.TensorStructure(dtypes.float32, []),
                structure.TensorStructure(dtypes.string, []),
            ),
        )

        int64_t = constant_op.constant(37, dtype=dtypes.int64)
        float32_t = constant_op.constant(42.0)
        string_t = constant_op.constant("Foo")
        components = [int64_t, float32_t, string_t]

        flat = structure.to_tensor_list(spec, (int64_t, (float32_t, string_t)))
        for want, got in zip(components, flat):
            self.assertIs(want, got)

        # Both rebuild entry points must hand back the original tensors.
        for rebuild in (structure.from_tensor_list,
                        structure.from_compatible_tensor_list):
            got_int64, (got_float32, got_string) = rebuild(spec, flat)
            self.assertIs(int64_t, got_int64)
            self.assertIs(float32_t, got_float32)
            self.assertIs(string_t, got_string)

    @parameterized.named_parameters(
        ("Tensor",
         structure.TensorStructure(dtypes.float32, []),
         32,
         structure.TensorStructure(dtypes.float32, [32])),
        ("TensorUnknown",
         structure.TensorStructure(dtypes.float32, []),
         None,
         structure.TensorStructure(dtypes.float32, [None])),
        ("SparseTensor",
         structure.SparseTensorStructure(dtypes.float32, [None]),
         32,
         structure.SparseTensorStructure(dtypes.float32, [32, None])),
        ("SparseTensorUnknown",
         structure.SparseTensorStructure(dtypes.float32, [4]),
         None,
         structure.SparseTensorStructure(dtypes.float32, [None, 4])),
        ("RaggedTensor",
         structure.RaggedTensorStructure(dtypes.float32, [2, None], 1),
         32,
         structure.RaggedTensorStructure(dtypes.float32, [32, 2, None], 2)),
        ("RaggedTensorUnknown",
         structure.RaggedTensorStructure(dtypes.float32, [4, None], 1),
         None,
         structure.RaggedTensorStructure(dtypes.float32, [None, 4, None], 2)),
        ("Nested",
         {
             "a": structure.TensorStructure(dtypes.float32, []),
             "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                   structure.TensorStructure(dtypes.string, [])),
         },
         128,
         {
             "a": structure.TensorStructure(dtypes.float32, [128]),
             "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                   structure.TensorStructure(dtypes.string, [128])),
         }),
    )
    def testBatch(self, element_structure, batch_size,
                  expected_batched_structure):
        """Checks `_batch(batch_size)` applied to every component spec.

        Args:
          element_structure: Nest of component specs to batch.
          batch_size: Value handed to each component's `_batch` (may be None).
          expected_batched_structure: Nest the batched result must equal.
        """

        def add_batch_dim(component_spec):
            return component_spec._batch(batch_size)

        self.assertEqual(
            nest.map_structure(add_batch_dim, element_structure),
            expected_batched_structure)

    @parameterized.named_parameters(
        ("Tensor",
         structure.TensorStructure(dtypes.float32, [32]),
         structure.TensorStructure(dtypes.float32, [])),
        ("TensorUnknown",
         structure.TensorStructure(dtypes.float32, [None]),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor",
         structure.SparseTensorStructure(dtypes.float32, [32, None]),
         structure.SparseTensorStructure(dtypes.float32, [None])),
        ("SparseTensorUnknown",
         structure.SparseTensorStructure(dtypes.float32, [None, 4]),
         structure.SparseTensorStructure(dtypes.float32, [4])),
        ("RaggedTensor",
         structure.RaggedTensorStructure(dtypes.float32, [32, None, None], 2),
         structure.RaggedTensorStructure(dtypes.float32, [None, None], 1)),
        ("RaggedTensorUnknown",
         structure.RaggedTensorStructure(
             dtypes.float32, [None, None, None], 2),
         structure.RaggedTensorStructure(dtypes.float32, [None, None], 1)),
        ("Nested",
         {
             "a": structure.TensorStructure(dtypes.float32, [128]),
             "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
                   structure.TensorStructure(dtypes.string, [None])),
         },
         {
             "a": structure.TensorStructure(dtypes.float32, []),
             "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                   structure.TensorStructure(dtypes.string, [])),
         }),
    )
    def testUnbatch(self, element_structure, expected_unbatched_structure):
        """Checks `_unbatch()` applied to every component spec.

        Args:
          element_structure: Nest of batched component specs.
          expected_unbatched_structure: Nest the unbatched result must equal.
        """

        def drop_batch_dim(component_spec):
            return component_spec._unbatch()

        self.assertEqual(
            nest.map_structure(drop_batch_dim, element_structure),
            expected_unbatched_structure)

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor",
         lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
         lambda: constant_op.constant([1.0, 2.0])),
        ("SparseTensor",
         lambda: sparse_tensor.SparseTensor(
             indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[0]], values=[13], dense_shape=[2])),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
         lambda: ragged_factory_ops.constant([[1]])),
        ("Nest",
         lambda: (
             constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
             sparse_tensor.SparseTensor(
                 indices=[[0, 0], [1, 1]], values=[13, 27],
                 dense_shape=[2, 2]),
         ),
         lambda: (
             constant_op.constant([1.0, 2.0]),
             sparse_tensor.SparseTensor(
                 indices=[[0]], values=[13], dense_shape=[2]),
         )),
    )
    def testToBatchedTensorList(self, value_fn, element_0_fn):
        """Checks `to_batched_tensor_list` output and its 0th element.

        Args:
          value_fn: Zero-argument callable producing the batched value.
          element_0_fn: Zero-argument callable producing the expected 0th
            element of that batch.
        """
        batched_value = value_fn()
        batched_spec = structure.type_spec_from_value(batched_value)
        batched_tensor_list = structure.to_batched_tensor_list(
            batched_spec, batched_value)

        # The batch dimension is 2 for all of the test cases.
        # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
        # tensors in which we store sparse tensors, so those are skipped.
        for component in batched_tensor_list:
            if component.dtype != dtypes.variant:
                leading_dim = array_ops.shape(component)[0]
                self.assertEqual(2, self.evaluate(leading_dim))

        # The 0th slice of every batched component, rebuilt with the unbatched
        # structure, must match the expected 0th element.
        expected_element_0 = self.evaluate(element_0_fn())

        def drop_batch_dim(component_spec):
            return component_spec._unbatch()

        element_spec = nest.map_structure(drop_batch_dim, batched_spec)
        first_slices = [component[0] for component in batched_tensor_list]
        actual_element_0 = structure.from_tensor_list(
            element_spec, first_slices)

        flat_expected = nest.flatten(expected_element_0)
        flat_actual = nest.flatten(actual_element_0)
        for expected, actual in zip(flat_expected, flat_actual):
            self.assertValuesEqual(expected, actual)
Beispiel #28
0
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `dataset_ops.Dataset`.

  Covers graph serialization, variant tracing, input tracking via `_inputs()`,
  dataset structure inference, graph-mismatch errors, eager iteration and use
  of datasets as `def_function.function` arguments.
  """

  def testAsSerializedGraph(self):
    """The serialized graph of a range dataset has a `RangeDataset` node."""
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    # Assert that a RangeDataset node is present. The previous `!=` check was
    # satisfied by any graph containing at least one other kind of node.
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  def testAsFunctionWithMap(self):
    """A dataset revived from its traced variant yields the same elements."""
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

  def testAsFunctionWithMapInFlatMap(self):
    """Variant tracing also works with a `map` nested inside `flat_map`."""
    if not context.executing_eagerly():
      self.skipTest("Only works executing eagerly")
    with ops.device("CPU"):
      original_dataset = dataset_ops.Dataset.range(5).flat_map(
          lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
      fn = original_dataset._trace_variant_creation()
      variant = fn()

      revived_dataset = _RevivedDataset(
          variant, original_dataset._element_structure)
      self.assertDatasetProduces(revived_dataset, list(original_dataset))

  @staticmethod
  def make_apply_fn(dataset):
    """Returns a function that applies `cache()` through `Dataset.apply`.

    Args:
      dataset: Unused; the returned function takes the dataset to transform.
    """

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():
    """Returns a generator function that yields the single value 42."""

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):
    """Returns a function that interleaves empty range datasets.

    Args:
      dataset: Unused; the returned function takes the dataset to transform.
      num_parallel_calls: Forwarded to `Dataset.interleave`.
    """

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord",
       lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      # This case previously called `from_tensors`, duplicating the case above
      # and leaving `from_tensor_slices` untested.
      ("FromTensorSlices",
       lambda: dataset_ops.Dataset.from_tensor_slices([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    """Checks how many inputs a source dataset reports via `_inputs()`.

    Args:
      dataset_fn: Zero-argument callable producing the dataset under test.
      num_inputs: Expected length of `_inputs()`.
    """
    self.assertLen(dataset_fn()._inputs(), num_inputs)

  @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
  def testDatasetComplexSourceInputs(self):
    """A dataset built from sparse tensor slices reports no inputs."""
    dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
        sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1])))
    self.assertEmpty(dataset_fn._inputs())

  @parameterized.named_parameters(
      ("Batch",
       lambda x: x.batch(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Cache",
       lambda x: x.cache(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Filter",
       lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap",
       lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map",
       lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch",
       lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap",
       lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat",
       lambda x: x.repeat(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle",
       lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip",
       lambda x: x.skip(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Take",
       lambda x: x.take(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Window",
       lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    """A unary transformation reports exactly its input dataset.

    Args:
      dataset_fn: Callable applying the transformation to a dataset.
      input_dataset_fn: Zero-argument callable producing the input dataset.
    """
    input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
    """`Dataset.apply` reports the dataset it was applied to as its input."""
    input_dataset = dataset_ops.Dataset.range(0)
    dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave",
       [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave",
       [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    """`interleave` reports exactly its input dataset.

    Args:
      interleave_fn_args: Positional arguments for `make_interleave_fn`.
      input_dataset_fn: Zero-argument callable producing the input dataset.
    """
    input_dataset = input_dataset_fn()
    dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testNoWarnings(self):
    """Building an interleave pipeline does not call `warnings.warn`."""
    with test.mock.patch.object(warnings, "warn") as mock_log:
      dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
      dataset_fn(dataset_ops.Dataset.range(10))
      self.assertEmpty(mock_log.call_args_list)

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    """A binary transformation reports both inputs, in order.

    Args:
      dataset_fn: Callable combining two datasets.
      input1_fn: Zero-argument callable producing the first input.
      input2_fn: Zero-argument callable producing the second input.
    """
    input1 = input1_fn()
    input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
      ("ZipOne",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    """A variadic transformation reports the flattened nest of its inputs.

    Args:
      dataset_fn: Callable combining a nest of datasets.
      input_datasets_fn: Zero-argument callable producing that nest.
    """
    input_datasets = input_datasets_fn()
    self.assertEqual(
        nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  def testFunctions(self):
    """A dataset with a single `map` reports one function."""
    dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    self.assertLen(dataset._functions(), 1)

  def testCollectInputs(self):
    """Breadth-first traversal over `_inputs()` visits shared inputs again."""
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    inputs = []
    queue = [ds3]
    while queue:
      ds = queue.pop(0)
      queue.extend(ds._inputs())
      inputs.append(ds)

    # ds1 is reached once directly from ds3 and twice through each ds2.
    self.assertEqual(5, inputs.count(ds1))
    self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testDatasetStructure(self, tf_value_fn, expected_element_structure):
    """Checks the structure inferred for a dataset and its round trip.

    Args:
      tf_value_fn: Zero-argument callable producing the dataset's element.
      expected_element_structure: Structure the inferred element structure
        must be mutually compatible with.
    """
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    self.assertTrue(expected_element_structure.is_compatible_with(
        dataset_structure._element_structure))
    self.assertTrue(dataset_structure._element_structure.is_compatible_with(
        expected_element_structure))

    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_dataset = dataset_structure._from_tensor_list(
        dataset_structure._to_tensor_list(dataset))

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      # NOTE(review): this branch compares against `dataset`, not
      # `round_trip_dataset`, so the round trip itself is not exercised for
      # dataset-valued elements. Confirm whether that is intentional.
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    elif isinstance(value, optional_ops.Optional):
      self.assertDatasetProduces(
          round_trip_dataset.map(lambda opt: opt.get_value()),
          [self.evaluate(value.get_value())],
          requires_initialization=True)
    else:
      self.assertDatasetProduces(
          round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShot(self):
    """Transforming a dataset inside a different graph raises ValueError."""
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    """A one-shot iterator made in a different graph logs a warning."""
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = dataset_ops.make_one_shot_iterator(dataset)
        self.assertRegexpMatches(
            str(mock_log.call_args), "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
  def testSkipEagerSameGraphErrorInitializable(self):
    """Transforming a dataset inside a different graph raises ValueError.

    NOTE(review): the body is identical to `testSkipEagerSameGraphErrorOneShot`
    and creates no initializable iterator; confirm the intended coverage.
    """
    dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @parameterized.named_parameters(
      ("Async", context.ASYNC),
      ("Sync", context.SYNC),
  )
  def testDatasetEagerIteration(self, execution_mode):
    """Iterating a range dataset eagerly yields 0, 1, 2, ... in order.

    Args:
      execution_mode: Eager execution mode to run the iteration under.
    """
    with context.eager_mode(), context.execution_mode(execution_mode):
      val = 0
      dataset = dataset_ops.Dataset.range(10)
      for foo in dataset:
        self.assertEqual(val, foo.numpy())
        val += 1

  def testDatasetAsFunctionArgument(self):
    """Datasets can be passed to a `def_function.function` and share traces."""

    @def_function.function
    def _uses_dataset(d):
      accumulator = array_ops.zeros([], dtype=dtypes.int64)
      for value in d:
        accumulator += value
      return accumulator

    with ops.device("CPU"):
      first_dataset = dataset_ops.Dataset.range(10)
      self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
      second_dataset = dataset_ops.Dataset.range(11)
      self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
      first_concrete = _uses_dataset.get_concrete_function(first_dataset)
      # The dataset should not be a captured input
      self.assertEmpty(first_concrete.graph.captures)
      # The two datasets have the same structure and so should re-use a trace.
      self.assertIs(first_concrete,
                    _uses_dataset.get_concrete_function(second_dataset))
      # With a different structure we should use a different trace.
      self.assertIsNot(
          first_concrete,
          _uses_dataset.get_concrete_function(
              dataset_ops.Dataset.zip((first_dataset, second_dataset))))

  def testLimitedRetracing(self):
    """Calling a function with same-structure datasets traces it only once."""
    trace_count = [0]

    @def_function.function
    def f(ds):
      # Runs only while tracing, so this counts traces rather than calls.
      trace_count[0] += 1
      counter = np.int64(0)
      for elem in ds:
        counter += elem
      return counter

    dataset = dataset_ops.Dataset.range(5)
    dataset2 = dataset_ops.Dataset.range(10)

    for _ in range(10):
      self.assertEqual(self.evaluate(f(dataset)), 10)
      self.assertEqual(self.evaluate(f(dataset2)), 45)
      self.assertEqual(trace_count[0], 1)
Beispiel #29
0
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
    """Checks that a serialized range dataset graph has a RangeDataset node."""
    dataset = dataset_ops.Dataset.range(10)
    graph = graph_pb2.GraphDef().FromString(
        self.evaluate(dataset._as_serialized_graph()))
    # Compare with `==`: a `!=` comparison is satisfied by any unrelated node
    # in the graph and therefore never verifies that the range op was
    # actually serialized.
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  @staticmethod
  def make_apply_fn(dataset):

    def apply_fn(dataset):

      def _apply_fn(dataset):
        return dataset.cache()

      return dataset.apply(_apply_fn)

    return apply_fn

  @staticmethod
  def make_gen():

    def gen():
      yield 42

    return gen

  @staticmethod
  def make_interleave_fn(dataset, num_parallel_calls=None):

    def interleave_fn(dataset):
      return dataset.interleave(
          lambda x: dataset_ops.Dataset.range(0),
          cycle_length=2,
          num_parallel_calls=num_parallel_calls)

    return interleave_fn

  @parameterized.named_parameters(
      ("FixedLengthRecord",
       lambda: readers.FixedLengthRecordDataset("", 42)),
      ("FromGenerator",
       lambda: dataset_ops.Dataset.from_generator(
           DatasetTest.make_gen(), dtypes.int32),
       1),
      ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
      # This case must exercise `from_tensor_slices`; calling `from_tensors`
      # here would merely duplicate the "FromTensors" case above.
      ("FromTensorSlices",
       lambda: dataset_ops.Dataset.from_tensor_slices([42])),
      ("Range", lambda: dataset_ops.Dataset.range(10)),
      ("TextLine", lambda: readers.TextLineDataset("")),
      ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
  )
  def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
    """Checks how many inputs a source dataset reports via `_inputs()`.

    Args:
      dataset_fn: Zero-argument callable that builds the dataset under test.
      num_inputs: Expected length of the dataset's `_inputs()` list.
    """
    self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

  def testDatasetComplexSourceInputs(self):
    """Checks that a sparse-tensor-slices dataset reports no inputs."""
    sparse_value = sparse_tensor.SparseTensor(
        indices=np.array([[0, 0], [1, 0], [2, 0]]),
        values=np.array([0, 0, 0]),
        dense_shape=np.array([3, 1]))
    dataset = dataset_ops.Dataset.from_sparse_tensor_slices(sparse_value)
    self.assertEqual(0, len(dataset._inputs()))

  @parameterized.named_parameters(
      ("Batch",
       lambda x: x.batch(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Cache",
       lambda x: x.cache(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Filter",
       lambda x: x.filter(lambda x: True),
       lambda: dataset_ops.Dataset.range(0)),
      ("FlatMap",
       lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
       lambda: dataset_ops.Dataset.range(0)),
      ("Map",
       lambda x: x.map(lambda x: x),
       lambda: dataset_ops.Dataset.range(0)),
      ("PaddedBatch",
       lambda x: x.padded_batch(10, []),
       lambda: dataset_ops.Dataset.range(0)),
      ("ParallelMap",
       lambda x: x.map(lambda x: x, num_parallel_calls=2),
       lambda: dataset_ops.Dataset.range(0)),
      ("Repeat",
       lambda x: x.repeat(),
       lambda: dataset_ops.Dataset.range(0)),
      ("Shuffle",
       lambda x: x.shuffle(10),
       lambda: dataset_ops.Dataset.range(0)),
      ("Skip",
       lambda x: x.skip(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Take",
       lambda x: x.take(1),
       lambda: dataset_ops.Dataset.range(0)),
      ("Window",
       lambda x: x.window(10),
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
    """Checks that a unary transformation reports its source as its only input.

    Args:
      dataset_fn: Callable that applies the transformation to a dataset.
      input_dataset_fn: Zero-argument callable that builds the source dataset.
    """
    source = input_dataset_fn()
    transformed = dataset_fn(source)
    self.assertEqual([source], transformed._inputs())

  def testUnaryTransformationInputsApply(self):
    """Checks `_inputs()` of a dataset built by the `make_apply_fn` helper."""
    source = dataset_ops.Dataset.range(0)
    apply_transformation = self.make_apply_fn(dataset_ops.Dataset.range(0))
    transformed = apply_transformation(source)
    self.assertEqual([source], transformed._inputs())

  @parameterized.named_parameters(
      ("ParallelInterleave",
       [lambda: dataset_ops.Dataset.range(0), 2],
       lambda: dataset_ops.Dataset.range(0)),
      ("Interleave",
       [lambda: dataset_ops.Dataset.range(0), None],
       lambda: dataset_ops.Dataset.range(0)),
  )
  def testUnaryTransformationInputsWithInterleaveFn(
      self, interleave_fn_args, input_dataset_fn):
    """Checks `_inputs()` of a dataset built by `make_interleave_fn`.

    Args:
      interleave_fn_args: Positional arguments unpacked into
        `make_interleave_fn`.
      input_dataset_fn: Zero-argument callable that builds the source dataset.
    """
    source = input_dataset_fn()
    interleave = self.make_interleave_fn(*interleave_fn_args)
    interleaved = interleave(source)
    self.assertEqual([source], interleaved._inputs())

  @parameterized.named_parameters(
      ("Concatenate", lambda x, y: x.concatenate(y),
       lambda: dataset_ops.Dataset.range(0),
       lambda: dataset_ops.Dataset.range(1)))
  def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
    """Checks that a binary transformation reports both sources, in order.

    Args:
      dataset_fn: Callable that combines two datasets into one.
      input1_fn: Zero-argument callable that builds the first source dataset.
      input2_fn: Zero-argument callable that builds the second source dataset.
    """
    first = input1_fn()
    second = input2_fn()
    combined = dataset_fn(first, second)
    self.assertEqual([first, second], combined._inputs())

  @parameterized.named_parameters(
      ("ZipOne",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0))),
      ("ZipNest",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                (dataset_ops.Dataset.range(1),
                 dataset_ops.Dataset.range(2)))),
      ("ZipTuple",
       dataset_ops.Dataset.zip,
       lambda: (dataset_ops.Dataset.range(0),
                dataset_ops.Dataset.range(1))),
  )
  def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
    """Checks that a variadic transformation reports its flattened sources.

    Args:
      dataset_fn: Callable that combines a (possibly nested) structure of
        datasets into one dataset.
      input_datasets_fn: Zero-argument callable that builds that structure.
    """
    sources = input_datasets_fn()
    expected_inputs = nest.flatten(sources)
    actual_inputs = dataset_fn(sources)._inputs()
    self.assertEqual(expected_inputs, actual_inputs)

  def testCollectInputs(self):
    """Walks the input graph of a zipped dataset and counts each dataset.

    Every dataset is recorded once per path that reaches it, so a dataset
    that feeds several downstream datasets is counted several times.
    """
    ds1 = dataset_ops.Dataset.range(0)
    ds2 = ds1.concatenate(ds1)
    ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

    # Breadth-first traversal starting at `ds3`, following `_inputs()`.
    visited = []
    pending = [ds3]
    while pending:
      current = pending.pop(0)
      pending.extend(current._inputs())
      visited.append(current)

    self.assertEqual(5, visited.count(ds1))
    self.assertEqual(2, visited.count(ds2))
    self.assertEqual(1, visited.count(ds3))

  # TODO(b/119882922): use-after-free bug in eager mode.
  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       structure.TensorStructure(dtypes.float32, [])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[1]),
       structure.SparseTensorStructure(dtypes.int32, [1])),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
       structure.NestedStructure({
           "a": structure.TensorStructure(dtypes.float32, []),
           "b": (structure.TensorStructure(dtypes.string, [1]),
                 structure.TensorStructure(dtypes.string, []))})),
      ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
          constant_op.constant([1, 2, 3])),
       dataset_ops.DatasetStructure(
           structure.TensorStructure(dtypes.int32, []))),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalStructure(
           structure.TensorStructure(dtypes.float32, []))),
  )
  def testSkipEagerDatasetStructure(self, tf_value_fn,
                                    expected_element_structure):
    """Checks the `DatasetStructure` of a dataset whose element is a TF value.

    Args:
      tf_value_fn: Zero-argument callable producing the value that each
        element of the dataset under test holds.
      expected_element_structure: Structure that must be mutually compatible
        with the element structure inferred for the dataset.
    """
    dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
    dataset_structure = structure.Structure.from_value(dataset)
    self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

    # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
    # the element structure.
    element_structure = dataset_structure._element_structure
    self.assertTrue(
        expected_element_structure.is_compatible_with(element_structure))
    self.assertTrue(
        element_structure.is_compatible_with(expected_element_structure))

    # The dataset itself flattens to a single scalar variant tensor.
    self.assertEqual([dtypes.variant], dataset_structure._flat_types)
    self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

    # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    flat_tensors = dataset_structure._to_tensor_list(dataset)
    round_trip_dataset = dataset_structure._from_tensor_list(flat_tensors)

    value = tf_value_fn()

    if isinstance(value, dataset_ops.Dataset):
      # NOTE(review): this branch compares against `dataset`, not
      # `round_trip_dataset`, unlike the other branches - confirm intent.
      self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
      return

    if isinstance(value, optional_ops.Optional):
      produced = round_trip_dataset.map(lambda opt: opt.get_value())
      expected_output = [self.evaluate(value.get_value())]
    else:
      produced = round_trip_dataset
      expected_output = [self.evaluate(tf_value_fn())]
    self.assertDatasetProduces(
        produced, expected_output, requires_initialization=True)

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShot(self):
    """Checks the warning logged for a one-shot iterator across graphs.

    The range dataset is created outside a fresh graph, batched inside it, and
    the call to `logging.warning` is inspected for the expected message.
    """
    expected_warning = ("Please ensure that all datasets in the pipeline are "
                        "created in the same graph as the iterator.")
    range_dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      batched = range_dataset.batch(2)
      with test.mock.patch.object(logging, "warning") as mock_log:
        _ = batched.make_one_shot_iterator()
        self.assertRegexpMatches(str(mock_log.call_args), expected_warning)

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorOneShotSimple(self):
    """Checks the cross-graph warning when no transformation is applied.

    The range dataset is created outside a fresh graph and a one-shot iterator
    is requested for it inside that graph.
    """
    expected_warning = ("Please ensure that all datasets in the pipeline are "
                        "created in the same graph as the iterator.")
    range_dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default(), test.mock.patch.object(
        logging, "warning") as mock_log:
      _ = range_dataset.make_one_shot_iterator()
      self.assertRegexpMatches(str(mock_log.call_args), expected_warning)

  @test_util.run_deprecated_v1
  def testSkipEagerSameGraphErrorInitializable(self):
    """Checks that an initializable iterator across graphs raises ValueError.

    The range dataset is created outside a fresh graph and batched inside it
    before the iterator is requested.
    """
    range_dataset = dataset_ops.Dataset.range(10)
    with ops.Graph().as_default():
      batched = range_dataset.batch(2)
      with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        _ = batched.make_initializable_iterator()
class StructureTest(test.TestCase, parameterized.TestCase):

    # NOTE(mrry): The arguments must be lifted into lambdas because otherwise they
    # will be executed before the (eager- or graph-mode) test environment has been
    # set up.
    # pylint: disable=g-long-lambda,protected-access
    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), structure.TensorStructure,
         [dtypes.float32], [[]]),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         structure.SparseTensorStructure, [dtypes.variant], [None]),
        (lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
         structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]
                                                                       ]),
        (lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, structure.NestedStructure,
         [dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None]))
    def testFlatStructure(self, value_fn, expected_structure, expected_types,
                          expected_shapes):
        """Checks the structure class and flat types/shapes of a value.

        Args:
          value_fn: Zero-argument callable producing the value to inspect.
          expected_structure: Class the inferred structure must be an
            instance of.
          expected_types: Expected value of the structure's `_flat_types`.
          expected_shapes: Shapes that must be mutually compatible with the
            structure's `_flat_shapes`, compared pairwise.
        """
        inferred = structure.Structure.from_value(value_fn())
        self.assertIsInstance(inferred, expected_structure)
        self.assertEqual(expected_types, inferred._flat_types)

        shape_pairs = zip(expected_shapes, inferred._flat_shapes)
        for expected_shape, actual_shape in shape_pairs:
            # Compatibility is checked in both directions.
            self.assertTrue(actual_shape.is_compatible_with(expected_shape))
            expected_as_shape = tensor_shape.as_shape(expected_shape)
            self.assertTrue(expected_as_shape.is_compatible_with(actual_shape))

    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), lambda: [
            constant_op.constant(38.0),
            array_ops.placeholder(dtypes.float32),
            variables.Variable(100.0), 42.0,
            np.array(42.0, dtype=np.float32)
        ],
         lambda: [constant_op.constant([1.0, 2.0]),
                  constant_op.constant(37)]),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [
                sparse_tensor.SparseTensor(indices=[[1, 1], [3, 4]],
                                           values=[10, -1],
                                           dense_shape=[4, 5]),
                sparse_tensor.SparseTensorValue(indices=[[1, 1], [3, 4]],
                                                values=[10, -1],
                                                dense_shape=[4, 5]),
                array_ops.sparse_placeholder(dtype=dtypes.int32),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None])
            ], lambda: [
                constant_op.constant(37, shape=[4, 5]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None, None]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
            ]),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6])
        }], lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6, 7])
        }, {
            "a": constant_op.constant(15),
            "b": constant_op.constant([4, 5, 6])
        }, {
            "a":
            constant_op.constant(15),
            "b":
            sparse_tensor.SparseTensor(
                indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
        }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
    )
    def testIsCompatibleWithStructure(self, original_value_fn,
                                      compatible_values_fn,
                                      incompatible_values_fn):
        """Checks `is_compatible_with()` against matching and clashing values.

        Args:
          original_value_fn: Zero-argument callable producing the reference
            value whose structure is tested.
          compatible_values_fn: Zero-argument callable producing values whose
            structures the reference structure must accept.
          incompatible_values_fn: Zero-argument callable producing values
            whose structures the reference structure must reject.
        """
        reference = original_value_fn()
        should_match = compatible_values_fn()
        should_not_match = incompatible_values_fn()
        reference_structure = structure.Structure.from_value(reference)

        for candidate in should_match:
            candidate_structure = structure.Structure.from_value(candidate)
            self.assertTrue(
                reference_structure.is_compatible_with(candidate_structure))

        for candidate in should_not_match:
            candidate_structure = structure.Structure.from_value(candidate)
            self.assertFalse(
                reference_structure.is_compatible_with(candidate_structure))

    @parameterized.parameters(
        (lambda: constant_op.constant(37.0), ),
        (lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), ),
        (lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, ),
        (lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, ),
    )
    def testRoundTripConversion(self, value_fn):
        """Checks a value survives `_to_tensor_list` / `_from_tensor_list`.

        Args:
          value_fn: Zero-argument callable producing the value to round-trip.
        """
        value = value_fn()
        s = structure.Structure.from_value(value)
        before = self.evaluate(value)
        round_tripped = s._from_tensor_list(s._to_tensor_list(value))
        after = self.evaluate(round_tripped)

        pairs = zip(nest.flatten(before), nest.flatten(after))
        for original, restored in pairs:
            if not isinstance(original, sparse_tensor.SparseTensorValue):
                self.assertAllEqual(original, restored)
                continue
            # Sparse values are compared component by component.
            self.assertAllEqual(original.indices, restored.indices)
            self.assertAllEqual(original.values, restored.values)
            self.assertAllEqual(original.dense_shape, restored.dense_shape)

    # pylint: enable=g-long-lambda

    def testIncompatibleStructure(self):
        """Checks errors raised when structures and values are mismatched."""
        # Define three mutually incompatible values/structures, and assert that:
        # 1. Using one structure to flatten a value with an incompatible structure
        #    fails.
        # 2. Using one structure to restructure a flattened value with an
        #    incompatible structure fails.

        def expect_error(error_type, pattern, convert, argument):
            # Runs `convert(argument)` and requires the given error to be raised.
            with self.assertRaisesRegexp(error_type, pattern):
                convert(argument)

        value_tensor = constant_op.constant(42.0)
        s_tensor = structure.Structure.from_value(value_tensor)
        flat_tensor = s_tensor._to_tensor_list(value_tensor)

        value_sparse_tensor = sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])
        s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor)
        flat_sparse_tensor = s_sparse_tensor._to_tensor_list(
            value_sparse_tensor)

        value_nest = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        s_nest = structure.Structure.from_value(value_nest)
        flat_nest = s_nest._to_tensor_list(value_nest)

        # Messages that are expected from more than one mismatched pairing.
        not_sparse_input = "Input must be a SparseTensor"
        single_variant = ("SparseTensorStructure corresponds to a single "
                          "tf.variant vector of length 3.")
        wrong_flat_count = "Expected 2 flat values in NestedStructure but got 1."

        # 1. Flattening a value with an incompatible structure.
        expect_error(
            ValueError,
            r"SparseTensor.* is not convertible to a tensor with "
            r"dtype.*float32.* and shape \(\)",
            s_tensor._to_tensor_list, value_sparse_tensor)
        expect_error(
            ValueError,
            r"Value \{.*\} is not convertible to a tensor with "
            r"dtype.*float32.* and shape \(\)",
            s_tensor._to_tensor_list, value_nest)

        expect_error(TypeError, not_sparse_input,
                     s_sparse_tensor._to_tensor_list, value_tensor)
        expect_error(TypeError, not_sparse_input,
                     s_sparse_tensor._to_tensor_list, value_nest)

        expect_error(
            ValueError,
            "Tensor.* not compatible with the nested structure "
            ".*TensorStructure.*TensorStructure",
            s_nest._to_tensor_list, value_tensor)
        expect_error(
            ValueError,
            "SparseTensor.* not compatible with the nested structure "
            ".*TensorStructure.*TensorStructure",
            s_nest._to_tensor_list, value_sparse_tensor)

        # 2. Restructuring a flat list with an incompatible structure.
        expect_error(
            ValueError,
            r"Cannot convert.*with dtype.*float32.* and shape \(\)",
            s_tensor._from_tensor_list, flat_sparse_tensor)
        expect_error(
            ValueError,
            "TensorStructure corresponds to a single tf.Tensor.",
            s_tensor._from_tensor_list, flat_nest)

        expect_error(ValueError, single_variant,
                     s_sparse_tensor._from_tensor_list, flat_tensor)
        expect_error(ValueError, single_variant,
                     s_sparse_tensor._from_tensor_list, flat_nest)

        expect_error(ValueError, wrong_flat_count,
                     s_nest._from_tensor_list, flat_tensor)
        expect_error(ValueError, wrong_flat_count,
                     s_nest._from_tensor_list, flat_sparse_tensor)

    def testIncompatibleNestedStructure(self):
        """Checks errors raised when nested structures and values mismatch.

        Each of the three structures is used to flatten the two values it is
        incompatible with, and to restructure their two flat tensor lists.
        """
        # Define three mutually incompatible nested values/structures, and assert
        # that:
        # 1. Using one structure to flatten a value with an incompatible structure
        #    fails.
        # 2. Using one structure to restructure a flattened value with an
        #    incompatible structure fails.

        value_0 = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        s_0 = structure.Structure.from_value(value_0)
        flat_s_0 = s_0._to_tensor_list(value_0)

        # `value_1` has compatible nested structure with `value_0`, but different
        # classes.
        value_1 = {
            "a":
            constant_op.constant(37.0),
            "b":
            sparse_tensor.SparseTensor(indices=[[0, 0]],
                                       values=[1],
                                       dense_shape=[1, 1])
        }
        s_1 = structure.Structure.from_value(value_1)
        flat_s_1 = s_1._to_tensor_list(value_1)

        # `value_2` has incompatible nested structure with `value_0` and `value_1`.
        value_2 = {
            "a":
            constant_op.constant(37.0),
            "b": (sparse_tensor.SparseTensor(indices=[[0, 0]],
                                             values=[1],
                                             dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(indices=[[3, 4]],
                                             values=[-1],
                                             dense_shape=[4, 5]))
        }
        s_2 = structure.Structure.from_value(value_2)
        flat_s_2 = s_2._to_tensor_list(value_2)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.* not compatible with the nested structure "
                ".*TensorStructure"):
            s_0._to_tensor_list(value_1)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.*SparseTensor.* not compatible with the "
                "nested structure .*TensorStructure"):
            s_0._to_tensor_list(value_2)

        with self.assertRaisesRegexp(
                ValueError,
                "Tensor.* not compatible with the nested structure "
                ".*SparseTensorStructure"):
            s_1._to_tensor_list(value_0)

        # Flatten `value_2` with `s_1` so that every structure is exercised
        # against both values it is incompatible with; repeating the
        # `s_0`/`value_2` check here would leave this pairing uncovered.
        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensor.*SparseTensor.* not compatible with the "
                "nested structure .*TensorStructure"):
            s_1._to_tensor_list(value_2)

        # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
        # needs to account for "a" coming before or after "b". It might be worth
        # adding a deterministic repr for these error messages (among other
        # improvements).
        with self.assertRaisesRegexp(
                ValueError,
                "Tensor.*Tensor.* not compatible with the nested structure "
                ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
                "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"
        ):
            s_2._to_tensor_list(value_0)

        with self.assertRaisesRegexp(
                ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
                "not compatible with the nested structure .*"
                "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
                "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"
        ):
            s_2._to_tensor_list(value_1)

        with self.assertRaisesRegexp(
                ValueError,
                r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"):
            s_0._from_tensor_list(flat_s_1)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 2 flat values in NestedStructure but got 3."):
            s_0._from_tensor_list(flat_s_2)

        with self.assertRaisesRegexp(
                ValueError,
                "SparseTensorStructure corresponds to a single tf.variant "
                "vector of length 3."):
            s_1._from_tensor_list(flat_s_0)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 2 flat values in NestedStructure but got 3."):
            s_1._from_tensor_list(flat_s_2)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 3 flat values in NestedStructure but got 2."):
            s_2._from_tensor_list(flat_s_0)

        with self.assertRaisesRegexp(
                ValueError,
                "Expected 3 flat values in NestedStructure but got 2."):
            s_2._from_tensor_list(flat_s_1)

    @parameterized.named_parameters(
        ("Tensor", dtypes.float32, tensor_shape.scalar(), ops.Tensor,
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", dtypes.int32, tensor_shape.matrix(
            2, 2), sparse_tensor.SparseTensor,
         structure.SparseTensorStructure(dtypes.int32, [2, 2])),
        ("Nest", {
            "a": dtypes.float32,
            "b": (dtypes.int32, dtypes.string)
        }, {
            "a": tensor_shape.scalar(),
            "b": (tensor_shape.matrix(2, 2), tensor_shape.scalar())
        }, {
            "a": ops.Tensor,
            "b": (sparse_tensor.SparseTensor, ops.Tensor)
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.SparseTensorStructure(dtypes.int32, [2, 2]),
                   structure.TensorStructure(dtypes.string, []))
         })),
    )
    def testFromLegacyStructure(self, output_types, output_shapes,
                                output_classes, expected_structure):
        """Checks the structure built from legacy types, shapes and classes.

        Args:
          output_types: Types passed to `_from_legacy_structure`.
          output_shapes: Shapes passed to `_from_legacy_structure`.
          output_classes: Classes passed to `_from_legacy_structure`.
          expected_structure: Structure that must be mutually compatible with
            the converted one.
        """
        converted = structure.Structure._from_legacy_structure(
            output_types, output_shapes, output_classes)
        # Compatibility is checked in both directions.
        self.assertTrue(expected_structure.is_compatible_with(converted))
        self.assertTrue(converted.is_compatible_with(expected_structure))