Exemple #1
0
 def compute_output_signature(self, input_spec):
     """Returns the output spec produced for `input_spec`.

     The output dtype is always int64 and the output shape comes from
     `compute_output_shape`.

     Args:
       input_spec: A single spec, or a tuple/list of specs.

     Returns:
       For a single spec: a `SparseTensorSpec` if the input is a
       `SparseTensorSpec`, otherwise a `TensorSpec`. For a tuple/list: a
       `TensorSpec` if any entry is a `RaggedTensorSpec`; otherwise a
       `SparseTensorSpec` if any entry is a `SparseTensorSpec`; otherwise
       a `TensorSpec`.
     """
     if isinstance(input_spec, (tuple, list)):
         shape = self.compute_output_shape([s.shape for s in input_spec])

         def _has(spec_type):
             return any(isinstance(s, spec_type) for s in input_spec)

         # A ragged entry is checked first, so it wins over a sparse one.
         if _has(ragged_tensor.RaggedTensorSpec):
             return tensor_spec.TensorSpec(shape=shape, dtype=dtypes.int64)
         if _has(sparse_tensor.SparseTensorSpec):
             return sparse_tensor.SparseTensorSpec(shape=shape,
                                                   dtype=dtypes.int64)
         return tensor_spec.TensorSpec(shape=shape, dtype=dtypes.int64)

     # Single-input case: only sparseness of the input is preserved.
     shape = self.compute_output_shape(input_spec.shape)
     if isinstance(input_spec, sparse_tensor.SparseTensorSpec):
         make_spec = sparse_tensor.SparseTensorSpec
     else:
         make_spec = tensor_spec.TensorSpec
     return make_spec(shape=shape, dtype=dtypes.int64)
Exemple #2
0
    def testConstruction(self):
        """Checks default and explicit `SparseTensorSpec` construction."""
        # Without arguments the spec is expected to have unknown rank and a
        # float32 dtype.
        default_spec = sparse_tensor.SparseTensorSpec()
        # assertIsNone states the intent directly and gives a clearer failure
        # message than comparing against None with assertEqual.
        self.assertIsNone(default_spec.shape.rank)
        self.assertEqual(default_spec.dtype, dtypes.float32)

        # Explicit shape and dtype must be reported back unchanged.
        string_spec = sparse_tensor.SparseTensorSpec([None, None],
                                                     dtypes.string)
        self.assertEqual(string_spec.shape.as_list(), [None, None])
        self.assertEqual(string_spec.dtype, dtypes.string)
    def testTypeSpec(self, distribution, enable_get_next_as_optional):
        """Checks the type spec of a distributed iterator over mixed inputs.

        The dataset carries dense, ragged and sparse views of the same ragged
        data. The iterator's `_type_spec` is expected to reference the
        iterator's input workers and to hold a two-component `PerReplicaSpec`
        for each of the three dataset entries.
        """
        if not tf2.enabled():
            self.skipTest("DistributedIterator has CompositeTensor support in "
                          "TF 2.0 only.")

        input_context = distribute_lib.InputContext()
        per_replica_batch = input_context.get_per_replica_batch_size(8)

        # 20 rows are not divisible by 8, which exercises partial batches.
        lengths = np.mod(np.arange(20), 4).astype(np.int64)
        flat_values = np.repeat(np.arange(20, dtype=np.float32), lengths)
        ragged = ragged_tensor_lib.RaggedTensor.from_row_lengths(
            flat_values, lengths)

        components = {
            "dense": ragged.to_tensor(),
            "ragged": ragged,
            "sparse": ragged.to_sparse(),
        }
        dataset = dataset_ops.DatasetV2.from_tensor_slices(components)
        dataset = dataset.shard(input_context.num_input_pipelines,
                                input_context.input_pipeline_id)
        dataset = dataset.batch(per_replica_batch)

        distribution.extended.experimental_enable_get_next_as_optional = (
            enable_get_next_as_optional)

        dist_dataset = distribution.experimental_distribute_dataset(dataset)
        with distribution.scope():
            iterator = iter(dist_dataset)
            _check_type_spec_structure(iterator)

        # Per-component specs; each appears twice inside its PerReplicaSpec.
        sparse_spec = sparse_tensor.SparseTensorSpec(
            tensor_shape.TensorShape([None, 3]), dtypes.float32)
        dense_spec = tensor_spec.TensorSpec(
            shape=(None, 3), dtype=dtypes.float32, name=None)
        ragged_spec = ragged_tensor_lib.RaggedTensorSpec(
            tensor_shape.TensorShape([None, None]), dtypes.float32, 1,
            dtypes.int64)
        expected_element_spec = {
            "sparse": values.PerReplicaSpec(sparse_spec, sparse_spec),
            "dense": values.PerReplicaSpec(dense_spec, dense_spec),
            "ragged": values.PerReplicaSpec(ragged_spec, ragged_spec),
        }

        spec = iterator._type_spec
        self.assertEqual(spec._input_workers, iterator._input_workers)
        self.assertEqual(spec._element_spec, expected_element_spec)
Exemple #4
0
def _test_convert_legacy_structure_combinations():
    """Builds the parameter combinations for legacy-structure conversion.

    Every case is an `(output_types, output_shapes, output_classes,
    expected_structure)` tuple. Each case is expanded with
    `combinations.combine` and the per-case results are concatenated in
    case order.
    """
    dense_case = (dtypes.float32, tensor_shape.TensorShape([]), ops.Tensor,
                  tensor_spec.TensorSpec([], dtypes.float32))
    sparse_case = (dtypes.int32, tensor_shape.TensorShape([2, 2]),
                   sparse_tensor.SparseTensor,
                   sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32))

    # For TensorArrays the first two legacy shape entries hold the
    # dynamic_size and infer_shape values expected in the resulting spec.
    tensor_array_cases = [
        (dtypes.int32,
         tensor_shape.TensorShape([dynamic_size, infer_shape, 2, 2]),
         tensor_array_ops.TensorArray,
         tensor_array_ops.TensorArraySpec([2, 2],
                                          dtypes.int32,
                                          dynamic_size=dynamic_size,
                                          infer_shape=infer_shape))
        for dynamic_size, infer_shape in ((None, True), (True, None),
                                          (True, False))
    ]

    # Here a spec instance is supplied in the "class" position.
    ragged_case = (dtypes.int32, tensor_shape.TensorShape([2, None]),
                   ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1),
                   ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1))

    nested_case = (
        {
            "a": dtypes.float32,
            "b": (dtypes.int32, dtypes.string),
        },
        {
            "a": tensor_shape.TensorShape([]),
            "b": (tensor_shape.TensorShape([2, 2]),
                  tensor_shape.TensorShape([])),
        },
        {
            "a": ops.Tensor,
            "b": (sparse_tensor.SparseTensor, ops.Tensor),
        },
        {
            "a": tensor_spec.TensorSpec([], dtypes.float32),
            "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                  tensor_spec.TensorSpec([], dtypes.string)),
        },
    )

    cases = [dense_case, sparse_case]
    cases.extend(tensor_array_cases)
    cases.extend([ragged_case, nested_case])

    result = []
    for types, shapes, classes, expected in cases:
        result = result + combinations.combine(output_types=types,
                                               output_shapes=shapes,
                                               output_classes=classes,
                                               expected_structure=expected)
    return result
Exemple #5
0
 def testEncodeDecodeSparseTensorSpec(self):
     """Round-trips a list holding one `SparseTensorSpec` through the coder.

     The encoded proto must equal the expected text proto below, and
     decoding it must give back the original structure.
     """
     spec_list = [sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)]
     self.assertTrue(self._coder.can_encode(spec_list))
     encoded_proto = self._coder.encode_structure(spec_list)

     expected_proto = struct_pb2.StructuredValue()
     text_format.Parse(
         r"""
   list_value {
     values {
       type_spec_value {
         type_spec_class: SPARSE_TENSOR_SPEC
         type_spec_class_name: 'SparseTensorSpec'
         num_flat_components: 3
         type_state {
           tuple_value {
             # spec._shape
             values {
               tensor_shape_value {
                 dim { size: 10 }
                 dim { size: 20 }
               }
             }
             # spec._dtype
             values { tensor_dtype_value: DT_FLOAT }
           }
         }
       }
     }
   }
 """, expected_proto)
     self.assertEqual(expected_proto, encoded_proto)

     round_tripped = self._coder.decode_proto(encoded_proto)
     self.assertEqual(spec_list, round_tripped)
Exemple #6
0
 def common_spec(x, y):
   """Builds a spec of `x`'s kind using the shape common to `x` and `y`.

   Only the shape of `y` is consulted; the dtype (and, for dense specs, the
   name) is taken from `x`.
   """
   merged_shape = get_common_shape(x.shape, y.shape)
   if isinstance(x, sparse_tensor.SparseTensorSpec):
     return sparse_tensor.SparseTensorSpec(merged_shape, x.dtype)
   if isinstance(x, ragged_tensor.RaggedTensorSpec):
     # NOTE(review): only shape and dtype are forwarded here; any other
     # ragged-specific settings of `x` fall back to the constructor
     # defaults -- confirm this is intended.
     return ragged_tensor.RaggedTensorSpec(merged_shape, x.dtype)
   return tensor_spec.TensorSpec(merged_shape, x.dtype, x.name)
Exemple #7
0
 def compute_output_signature(self, input_spec):
   """Returns an int64 output spec, sparse when `self.sparse` is truthy."""
   shape = self.compute_output_shape(input_spec.shape.as_list())
   spec_cls = (
       sparse_tensor.SparseTensorSpec if self.sparse
       else tensor_spec.TensorSpec)
   return spec_cls(shape=shape, dtype=dtypes.int64)
Exemple #8
0
def _optional_spec_test_combinations():
    """Builds the parameter combinations for the optional-spec tests.

    Every case is a `(name, value_fn, expected_structure)` tuple. The value
    function is wrapped in a `combinations.NamedObject` carrying the case
    name, and the per-case combinations are concatenated in case order.
    """
    # pylint: disable=g-long-lambda
    dense = ("Dense", lambda: constant_op.constant(37.0),
             tensor_spec.TensorSpec([], dtypes.float32))
    sparse = ("Sparse", lambda: sparse_tensor.SparseTensor(
        indices=[[0, 1]],
        values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[10, 10]),
              sparse_tensor.SparseTensorSpec([10, 10], dtypes.int32))
    nested = ("Nest", lambda: {
        "a": constant_op.constant(37.0),
        "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
    }, {
        "a": tensor_spec.TensorSpec([], dtypes.float32),
        "b": (
            tensor_spec.TensorSpec([1], dtypes.string),
            tensor_spec.TensorSpec([], dtypes.string),
        )
    })
    optional = ("Optional", lambda: optional_ops.Optional.from_value(37.0),
                optional_ops.OptionalSpec(
                    tensor_spec.TensorSpec([], dtypes.float32)))

    result = []
    for name, value_fn, expected_structure in (dense, sparse, nested,
                                               optional):
        result = result + combinations.combine(
            tf_value_fn=combinations.NamedObject(name, value_fn),
            expected_value_structure=expected_structure)
    return result
Exemple #9
0
    def __init__(self, input_dataset, batch_size, row_shape):
        """See `Dataset.dense_to_sparse_batch()` for more details.

        Raises:
          TypeError: If the legacy output types of `input_dataset` are not a
            single `DType`, i.e. the elements have more than one component.
        """
        single_component = isinstance(
            dataset_ops.get_legacy_output_types(input_dataset), dtypes.DType)
        if not single_component:
            raise TypeError(
                "DenseToSparseDataset requires an input whose elements "
                "have a single component, whereas the input has %r." %
                dataset_ops.get_legacy_output_types(input_dataset))

        self._input_dataset = input_dataset
        self._batch_size = batch_size
        self._row_shape = row_shape
        # Element shape is an unknown leading dimension followed by
        # `row_shape`.
        self._element_spec = sparse_tensor.SparseTensorSpec(
            tensor_shape.vector(None).concatenate(self._row_shape),
            dataset_ops.get_legacy_output_types(input_dataset))

        # Both op variants take the same arguments; only the op name depends
        # on the forward-compatibility window.
        if compat.forward_compatible(2019, 8, 3):
            batch_op = ged_ops.dense_to_sparse_batch_dataset
        else:
            batch_op = ged_ops.experimental_dense_to_sparse_batch_dataset
        variant_tensor = batch_op(
            self._input_dataset._variant_tensor,  # pylint: disable=protected-access
            self._batch_size,
            row_shape=convert.partial_shape_to_tensor(self._row_shape),
            **self._flat_structure)
        super(_DenseToSparseBatchDataset,
              self).__init__(input_dataset, variant_tensor)
Exemple #10
0
 def testSparseTensorDatasetSpec(self):
     """Checks the dataset spec for a rank-1 int32 sparse tensor element."""
     element_values = constant_op.constant([0], dtype=dtypes.int32)
     element = sparse_tensor.SparseTensor(indices=[[0]],
                                          values=element_values,
                                          dense_shape=[1])
     expected_spec = sparse_tensor.SparseTensorSpec([1], dtypes.int32)
     self.checkDatasetSpec(element, expected_spec)
Exemple #11
0
 def compute_output_signature(self, input_spec):
   """Returns an int64 output spec that is sparse iff the input spec is."""
   shape = self.compute_output_shape(input_spec.shape.as_list())
   if isinstance(input_spec, sparse_tensor.SparseTensorSpec):
     spec_cls = sparse_tensor.SparseTensorSpec
   else:
     spec_cls = tensor_spec.TensorSpec
   return spec_cls(shape=shape, dtype=dtypes.int64)
Exemple #12
0
def convert_legacy_structure(output_types, output_shapes, output_classes):
    """Returns a `Structure` that represents the given legacy structure.

    A "legacy" structure is the triple of `tf.data.Dataset.output_types`,
    `tf.data.Dataset.output_shapes` and `tf.data.Dataset.output_classes`.
    The three nests are flattened, each component is converted to a spec,
    and the specs are packed back into the nesting of `output_classes`.

    TODO(b/110122868): Remove this function once `Structure` is used
    throughout `tf.data`.

    Args:
      output_types: A nested structure of `tf.DType` objects corresponding
        to each component of a structured value.
      output_shapes: A nested structure of `tf.TensorShape` objects
        corresponding to each component a structured value.
      output_classes: A nested structure of Python `type` objects
        corresponding to each component of a structured value.

    Returns:
      A `Structure`.

    Raises:
      TypeError: If a structure cannot be built from the arguments, because
        one of the component classes in `output_classes` is not supported.
    """

    def _component_spec(component_type, component_shape, component_class):
        """Converts one flattened legacy component into its spec."""
        # A TypeSpec instance in the class position is used as-is.
        if isinstance(component_class, type_spec.TypeSpec):
            return component_class
        if issubclass(component_class, sparse_tensor.SparseTensor):
            return sparse_tensor.SparseTensorSpec(component_shape,
                                                  component_type)
        if issubclass(component_class, ops.Tensor):
            return tensor_spec.TensorSpec(component_shape, component_type)
        if issubclass(component_class, tensor_array_ops.TensorArray):
            # We sneaked the dynamic_size and infer_shape into the legacy
            # shape: they occupy its first two entries.
            return tensor_array_ops.TensorArraySpec(
                component_shape[2:],
                component_type,
                dynamic_size=tensor_shape.dimension_value(component_shape[0]),
                infer_shape=tensor_shape.dimension_value(component_shape[1]))
        # NOTE(mrry): Since legacy structures produced by iterators only
        # comprise Tensors, SparseTensors, and nests, we do not need to
        # support all structure types here.
        raise TypeError(
            "Could not build a structure for output class {}. Make sure any "
            "component class in `output_classes` inherits from one of the "
            "following classes: `tf.TypeSpec`, `tf.sparse.SparseTensor`, "
            "`tf.Tensor`, `tf.TensorArray`.".format(component_class.__name__))

    flat_components = zip(nest.flatten(output_types),
                          nest.flatten(output_shapes),
                          nest.flatten(output_classes))
    flat_specs = [_component_spec(*component) for component in flat_components]
    return nest.pack_sequence_as(output_classes, flat_specs)
 def compute_output_signature(self, input_spec):
   """Returns the output spec for `input_spec`.

   The dtype is `K.floatx()` when the output mode is `TFIDF` and int64
   otherwise; the spec is sparse when `self._sparse` is truthy.
   """
   shape = self.compute_output_shape(input_spec.shape.as_list())
   if self._output_mode == TFIDF:
     dtype = K.floatx()
   else:
     dtype = dtypes.int64
   spec_cls = (
       sparse_tensor.SparseTensorSpec if self._sparse
       else tensor_spec.TensorSpec)
   return spec_cls(shape=shape, dtype=dtype)
 def testEncodeDataSetSpec(self):
   """Round-trips a `DatasetSpec` with ragged, sparse and dense entries."""
   element_spec = {
       "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
       "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
       "t": tensor_spec.TensorSpec([10, 8], dtypes.string),
   }
   original = [dataset_ops.DatasetSpec(element_spec)]
   self.assertTrue(self._coder.can_encode(original))
   # Decoding the encoded proto must reproduce the original structure.
   round_tripped = self._coder.decode_proto(
       self._coder.encode_structure(original))
   self.assertEqual(original, round_tripped)
Exemple #15
0
 def testDatasetSpecConstructor(self):
   """Checks the fields stored by the `DatasetSpec` constructor."""
   element_spec = {
       "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
       "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
       "t": tensor_spec.TensorSpec([10, 8], dtypes.string),
   }
   ds_struct = dataset_ops.DatasetSpec(element_spec, [5])
   self.assertEqual(ds_struct._element_spec, element_spec)
   # Note: the list passed as dataset shape is expected to have been
   # converted to a TensorShape.
   expected_shape = tensor_shape.TensorShape([5])
   self.assertEqual(ds_struct._dataset_shape, expected_shape)
Exemple #16
0
 def testFromNumpyComponents(self):
     """Checks `_from_components` with numpy inputs.

     The result is expected to be a `SparseTensorValue` whose fields equal
     the numpy components that were passed in.
     """
     components = {
         "indices": np.array([[0], [8]]),
         "values": np.array([1.0, 9.0]),
         "dense_shape": np.array([100]),
     }
     spec = sparse_tensor.SparseTensorSpec()
     st = spec._from_components([
         components["indices"], components["values"],
         components["dense_shape"]
     ])
     self.assertIsInstance(st, sparse_tensor.SparseTensorValue)
     self.assertAllEqual(st.indices, components["indices"])
     self.assertAllEqual(st.values, components["values"])
     self.assertAllEqual(st.dense_shape, components["dense_shape"])
Exemple #17
0
def _test_unbatch_combinations():
    """Builds the parameter combinations for the unbatch tests.

    Every case pairs an element structure with the structure expected after
    unbatching it; in each pair the expected spec is the batched spec with
    its leading dimension removed. The per-case combinations are
    concatenated in case order.
    """
    float32 = dtypes.float32
    dense_cases = [
        (tensor_spec.TensorSpec([32], float32),
         tensor_spec.TensorSpec([], float32)),
        (tensor_spec.TensorSpec([None], float32),
         tensor_spec.TensorSpec([], float32)),
    ]
    sparse_cases = [
        (sparse_tensor.SparseTensorSpec([32, None], float32),
         sparse_tensor.SparseTensorSpec([None], float32)),
        (sparse_tensor.SparseTensorSpec([None, 4], float32),
         sparse_tensor.SparseTensorSpec([4], float32)),
    ]
    ragged_cases = [
        (ragged_tensor.RaggedTensorSpec([32, None, None], float32, 2),
         ragged_tensor.RaggedTensorSpec([None, None], float32, 1)),
        (ragged_tensor.RaggedTensorSpec([None, None, None], float32, 2),
         ragged_tensor.RaggedTensorSpec([None, None], float32, 1)),
    ]
    nested_case = ({
        "a": tensor_spec.TensorSpec([128], float32),
        "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
              tensor_spec.TensorSpec([None], dtypes.string)),
    }, {
        "a": tensor_spec.TensorSpec([], float32),
        "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
              tensor_spec.TensorSpec([], dtypes.string)),
    })

    result = []
    for batched, unbatched in (dense_cases + sparse_cases + ragged_cases +
                               [nested_case]):
        result = result + combinations.combine(
            element_structure=batched,
            expected_unbatched_structure=unbatched)
    return result
 def compute_output_signature(self, input_spec):
   """Returns the string-typed output spec for a sequence of input specs.

   A `SparseTensorSpec` is returned only when some input is sparse and no
   input is ragged; every other combination gives a dense `TensorSpec`.
   """
   shape = self.compute_output_shape([spec.shape for spec in input_spec])

   def _has(spec_type):
     return any(isinstance(spec, spec_type) for spec in input_spec)

   # Ragged inputs are checked first, so they win over sparse ones.
   if _has(ragged_tensor.RaggedTensorSpec):
     return tensor_spec.TensorSpec(shape=shape, dtype=dtypes.string)
   if _has(sparse_tensor.SparseTensorSpec):
     return sparse_tensor.SparseTensorSpec(shape=shape, dtype=dtypes.string)
   return tensor_spec.TensorSpec(shape=shape, dtype=dtypes.string)
def _most_general_compatible_type(spec):
  """Returns the most general TypeSpec compatible with `spec`.

  Dense, ragged and sparse specs are rebuilt with their shape replaced by
  `None`; any other spec is returned unchanged.
  """
  # TODO(edloper): Consider adding most_general_compatible_type to TypeSpec API
  # pylint: disable=protected-access
  if isinstance(spec, tensor_spec.TensorSpec):
    return tensor_spec.TensorSpec(None, spec.dtype)
  if isinstance(spec, ragged_tensor.RaggedTensorSpec):
    # The ragged rank and row-splits dtype are carried over unchanged.
    return ragged_tensor.RaggedTensorSpec(None, spec._dtype, spec._ragged_rank,
                                          spec._row_splits_dtype)
  if isinstance(spec, sparse_tensor.SparseTensorSpec):
    return sparse_tensor.SparseTensorSpec(None, spec.dtype)
  return spec
  def test_repr(self):
    """Checks `str()` and `repr()` of KerasTensors built in several ways."""

    def check(tensor, expected_repr):
      # Both string conversions must give the same text.
      self.assertEqual(expected_repr, str(tensor))
      self.assertEqual(expected_repr, repr(tensor))

    # Dense spec without an inferred value.
    check(
        keras_tensor.KerasTensor(
            type_spec=tensor_spec.TensorSpec(
                shape=(1, 2, 3), dtype=dtypes.float32)),
        "<KerasTensor: shape=(1, 2, 3) dtype=float32>")

    # Dense spec with an inferred shape value.
    check(
        keras_tensor.KerasTensor(
            type_spec=tensor_spec.TensorSpec(shape=(2,), dtype=dtypes.int32),
            inferred_shape_value=[2, 3]),
        "<KerasTensor: shape=(2,) dtype=int32 inferred_value='[2, 3]'>")

    # Sparse spec: the type spec itself is printed.
    check(
        keras_tensor.KerasTensor(
            type_spec=sparse_tensor.SparseTensorSpec(
                shape=(1, 2, 3), dtype=dtypes.float32)),
        "<KerasTensor: type_spec=SparseTensorSpec("
        "TensorShape([1, 2, 3]), tf.float32)>")

    with testing_utils.use_keras_tensors_scope(True):
      inp = layers.Input(shape=(3, 5))
      dense_out = layers.Dense(10)(inp)
      check(
          dense_out,
          "<KerasTensor: shape=(None, 3, 10) dtype=float32 (Symbolic value 0 "
          "from symbolic call 0 of layer 'dense')>")

      reshaped = array_ops.reshape(dense_out, shape=(3, 5, 2))
      check(
          reshaped,
          "<KerasTensor: shape=(3, 5, 2) dtype=float32 (Symbolic "
          "value 0 from symbolic call 0 of layer 'tf.reshape')>")

      # Each unstacked piece reports its own symbolic value index.
      pieces = array_ops.unstack(reshaped)
      for i in range(3):
        check(
            pieces[i],
            "<KerasTensor: shape=(5, 2) dtype=float32 "
            "(Symbolic value %s from symbolic call 0 "
            "of layer 'tf.unstack')>" % i)
  def testFromGeneratorSparseTensor(self):
    """Checks `from_generator` with a `SparseTensorSpec` output signature."""

    def sparse_elements():
      yield sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]],
          values=constant_op.constant([1, 2], dtype=dtypes.int64),
          dense_shape=[3, 4])

    signature = sparse_tensor.SparseTensorSpec([3, 4], dtypes.int64)
    dataset = dataset_ops.Dataset.from_generator(
        sparse_elements, output_signature=signature)

    first = self.getNext(dataset)()

    # The element must come back as a sparse tensor with the same contents.
    self.assertIsInstance(first, sparse_tensor.SparseTensor)
    expected_dense = [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]
    self.assertAllEqual(expected_dense,
                        sparse_ops.sparse_tensor_to_dense(first))
Exemple #22
0
  def __init__(self, input_dataset, batch_size, row_shape):
    """See `Dataset.dense_to_sparse_batch()` for more details.

    Raises:
      TypeError: If the legacy output types of `input_dataset` are not a
        single `DType`, i.e. the elements have more than one component.
    """
    has_single_component = isinstance(
        dataset_ops.get_legacy_output_types(input_dataset), dtypes.DType)
    if not has_single_component:
      raise TypeError("`dense_to_sparse_batch` requires an input dataset whose "
                      "elements have a single component, but the given dataset "
                      "has the following component types: "
                      f"{dataset_ops.get_legacy_output_types(input_dataset)}.")

    self._input_dataset = input_dataset
    self._batch_size = batch_size
    self._row_shape = row_shape
    # Element shape is an unknown leading dimension followed by `row_shape`.
    batched_shape = tensor_shape.TensorShape([None]).concatenate(
        self._row_shape)
    self._element_spec = sparse_tensor.SparseTensorSpec(
        batched_shape, dataset_ops.get_legacy_output_types(input_dataset))

    input_variant = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    row_shape_tensor = convert.partial_shape_to_tensor(self._row_shape)
    variant_tensor = ged_ops.dense_to_sparse_batch_dataset(
        input_variant,
        self._batch_size,
        row_shape=row_shape_tensor,
        **self._flat_structure)
    super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                     variant_tensor)
Exemple #23
0
class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase):
    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testNoGradients(self):
        """Gradients of an iterator output w.r.t. dataset inputs are None."""
        component = constant_op.constant([1.])
        side = constant_op.constant(0.)

        def add_side(x):
            return x + side

        dataset = dataset_ops.Dataset.from_tensor_slices(component).map(
            add_side)
        value = dataset_ops.make_one_shot_iterator(dataset).get_next()

        # Checked against each source alone and against both together.
        for sources in (component, side, [component, side]):
            self.assertIsNone(gradients_impl.gradients(value, sources)[0])

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testCapturingStateInOneShotRaisesException(self):
        """One-shot iterators must reject datasets that capture a variable.

        The raised `ValueError` is expected to mention both the unsupported
        stateful capture and the variable's name ("myvar").
        """
        var = variables.Variable(37.0, name="myvar")
        dataset = (dataset_ops.Dataset.from_tensor_slices(
            [0.0, 1.0, 2.0]).map(lambda x: x + var))
        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; `assertRaisesRegex` is the supported spelling.
        with self.assertRaisesRegex(
                ValueError,
                r"`Dataset.make_one_shot_iterator\(\)` does not support "
                "datasets that capture stateful objects.+myvar"):
            dataset_ops.make_one_shot_iterator(dataset)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testOneShotIterator(self):
        """Runs a one-shot iterator over squared slices for 14 epochs.

        Each produced element must equal the square of the matching slice of
        every component, and the iterator must raise `OutOfRangeError` once
        all 14 * 7 elements have been consumed.
        """
        num_rows, num_epochs = 7, 14
        components = (np.arange(num_rows),
                      np.array([[1, 2, 3]]) *
                      np.arange(num_rows)[:, np.newaxis],
                      np.array(37.0) * np.arange(num_rows))

        def _square_all(x, y, z):
            return math_ops.square(x), math_ops.square(y), math_ops.square(z)

        dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
            _square_all).repeat(num_epochs)
        get_next = dataset_ops.make_one_shot_iterator(dataset).get_next()

        # Slicing drops the leading dimension of every component.
        expected_shapes = [c.shape[1:] for c in components]
        self.assertEqual(expected_shapes, [t.shape for t in get_next])

        with self.cached_session() as sess:
            for step in range(num_epochs * num_rows):
                row = step % num_rows
                result = sess.run(get_next)
                for component, result_component in zip(components, result):
                    self.assertAllEqual(component[row]**2, result_component)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testOneShotIteratorCaptureByValue(self):
        """Same as the one-shot iterator test, but slicing tensor inputs.

        The components are converted to tensors before the dataset is built;
        results are still compared against the numpy originals.
        """
        num_rows, num_epochs = 7, 14
        components = (np.arange(num_rows),
                      np.array([[1, 2, 3]]) *
                      np.arange(num_rows)[:, np.newaxis],
                      np.array(37.0) * np.arange(num_rows))
        tensor_components = tuple(
            ops.convert_to_tensor(c) for c in components)

        def _square_all(x, y, z):
            return math_ops.square(x), math_ops.square(y), math_ops.square(z)

        dataset = dataset_ops.Dataset.from_tensor_slices(
            tensor_components).map(_square_all).repeat(num_epochs)
        get_next = dataset_ops.make_one_shot_iterator(dataset).get_next()

        # Slicing drops the leading dimension of every component.
        expected_shapes = [c.shape[1:] for c in components]
        self.assertEqual(expected_shapes, [t.shape for t in get_next])

        with self.cached_session() as sess:
            for step in range(num_epochs * num_rows):
                row = step % num_rows
                result = sess.run(get_next)
                for component, result_component in zip(components, result):
                    self.assertAllEqual(component[row]**2, result_component)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)

    @combinations.generate(test_base.default_test_combinations())
    def testOneShotIteratorInsideContainer(self):
        """Runs one-shot iterators created inside distinct resource containers.

        Two sessions against the same local server each build an iterator
        under their own `ops.container` name and must each be able to consume
        the full 14 * 7 elements before hitting `OutOfRangeError`.
        """
        components = (np.arange(7),
                      np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                      np.array(37.0) * np.arange(7))

        def within_container():
            # Builds the iterator; called while an `ops.container` scope is
            # active so the iterator is created under that container.
            def _map_fn(x, y, z):
                return math_ops.square(x), math_ops.square(y), math_ops.square(
                    z)

            iterator = dataset_ops.make_one_shot_iterator(
                dataset_ops.Dataset.from_tensor_slices(components).map(
                    _map_fn).repeat(14))
            return iterator.get_next()

        server = server_lib.Server.create_local_server()

        # Create two iterators within unique containers, and run them to
        # make sure that the resources aren't shared.
        #
        # The test below would fail if cname were the same across both
        # sessions.
        for j in range(2):
            with session.Session(server.target) as sess:
                cname = "iteration%d" % j
                with ops.container(cname):
                    get_next = within_container()

                # 14 repetitions of the 7 slices, each squared by _map_fn.
                for _ in range(14):
                    for i in range(7):
                        result = sess.run(get_next)
                        for component, result_component in zip(
                                components, result):
                            self.assertAllEqual(component[i]**2,
                                                result_component)
                # The iterator must be exhausted after the last repetition.
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testOneShotIteratorNonBlocking(self):
        """Checks a one-shot iterator under a single inter-op thread.

        First a single consumer reads the only element and then hits
        `OutOfRangeError`. Then, in a fresh session, eight threads race for
        that element: exactly one must receive it and the other seven must
        see `OutOfRangeError`.
        """
        dataset = dataset_ops.Dataset.from_tensors([1, 2,
                                                    3]).map(lambda x: x * x)
        iterator = dataset_ops.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()

        # Create a session with a single thread to ensure that the
        # one-shot iterator initializer does not deadlock.
        config = config_pb2.ConfigProto(inter_op_parallelism_threads=1,
                                        use_per_session_threads=True)
        with session.Session(config=config) as sess:
            self.assertAllEqual([1, 4, 9], sess.run(next_element))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(next_element)

        # Test with multiple threads invoking the one-shot iterator concurrently.
        with session.Session(config=config) as sess:
            results = []

            def consumer_thread():
                # Records the element, or None when the iterator was already
                # exhausted by another thread.
                try:
                    results.append(sess.run(next_element))
                except errors.OutOfRangeError:
                    results.append(None)

            num_threads = 8
            threads = [
                self.checkedThread(consumer_thread) for _ in range(num_threads)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            # Every thread reported once; exactly one of them got the element.
            self.assertEqual(num_threads, len(results))
            self.assertEqual(num_threads - 1,
                             len([None for r in results if r is None]))
            self.assertAllEqual([[1, 4, 9]],
                                [r for r in results if r is not None])

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testOneShotIteratorInitializerFails(self):
        """A one-shot iterator whose initialization fails keeps failing.

        Every attempt to run the iterator must raise `InvalidArgumentError`:
        on the first use, on a repeated use in the same session, and when
        eight threads use it concurrently in another session.
        """
        # Define a dataset whose initialization will always fail.
        dataset = dataset_ops.Dataset.from_tensors(array_ops.gather([0], [4]))
        iterator = dataset_ops.make_one_shot_iterator(dataset)
        next_element = iterator.get_next()

        # NOTE: `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12, so `assertRaisesRegex` is used throughout. The empty
        # pattern matches any message; only the exception type is checked.
        with self.cached_session() as sess:
            with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
                sess.run(next_element)

            # Test that subsequent attempts to use the iterator also fail.
            with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
                sess.run(next_element)

        with self.cached_session() as sess:

            def consumer_thread():
                with self.assertRaisesRegex(errors.InvalidArgumentError, ""):
                    sess.run(next_element)

            num_threads = 8
            threads = [
                self.checkedThread(consumer_thread) for _ in range(num_threads)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testSimpleSharedResource(self):
        """Checks that a named iterator resource is shared across sessions.

        Two non-overlapping sessions on the same local server use an iterator
        with the same `shared_name`. The first session initializes it; the
        second session must be able to consume it without initializing.
        """
        components = (
            np.array(1, dtype=np.int64),
            np.array([1, 2, 3], dtype=np.int64),
            np.array(37.0, dtype=np.float64),
        )

        def assert_single_element_then_end(sess, fetch):
            # The dataset holds exactly one element: `components`.
            produced = sess.run(fetch)
            for want, got in zip(components, produced):
                self.assertAllEqual(want, got)
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(fetch)

        server = server_lib.Server.create_local_server()

        # First graph/session: define the iterator from a dataset, consume
        # it, then leave it freshly initialized for the second session.
        with ops.Graph().as_default():
            source = dataset_ops.Dataset.from_tensors(components).map(
                lambda x, y, z: (x, y, z))
            iterator = dataset_ops.make_initializable_iterator(
                source, shared_name="shared_iterator")
            init_op = iterator.initializer
            get_next = iterator.get_next()

            with session.Session(server.target) as sess:
                sess.run(init_op)
                assert_single_element_then_end(sess, get_next)

                # Re-initialize the iterator in the first session.
                sess.run(init_op)

        # Second graph/session: re-define the iterator manually, without
        # defining any of the functions in this graph, to ensure that we are
        # not accidentally redefining functions with the same names.
        with ops.Graph().as_default():
            iterator = iterator_ops.Iterator.from_structure(
                shared_name="shared_iterator",
                output_types=(dtypes.int64, dtypes.int64, dtypes.float64),
                output_shapes=([], [3], []))
            get_next = iterator.get_next()

            with session.Session(server.target) as sess:
                # No initialization here: the state set up by the first
                # session must be visible.
                assert_single_element_then_end(sess, get_next)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testNotInitializedError(self):
        """Checks that using an uninitialized iterator reports a clear error."""
        dataset = dataset_ops.Dataset.from_tensors(
            (np.array(1), np.array([1, 2, 3]), np.array(37.0)))
        get_next = dataset_ops.make_initializable_iterator(dataset).get_next()

        expected_message = "iterator has not been initialized"
        # The initializer is deliberately never run.
        with self.cached_session() as sess, self.assertRaisesRegexp(
                errors.FailedPreconditionError, expected_message):
            sess.run(get_next)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testReinitializableIterator(self):
        """Checks that one iterator can be re-initialized from several datasets.

        The iterator is built from a structure with an unknown-length vector
        shape, then initialized in turn from a length-3 dataset, a length-4
        dataset and the length-3 dataset again.
        """
        values_3 = [1, 2, 3]
        values_4 = [4, 5, 6, 7]
        dataset_3 = dataset_ops.Dataset.from_tensors(
            constant_op.constant(values_3))
        dataset_4 = dataset_ops.Dataset.from_tensors(
            constant_op.constant(values_4))
        iterator = iterator_ops.Iterator.from_structure(
            dataset_ops.get_legacy_output_types(dataset_3), [None])

        dataset_3_init_op = iterator.make_initializer(dataset_3)
        dataset_4_init_op = iterator.make_initializer(dataset_4)
        get_next = iterator.get_next()

        # Both datasets must agree with the iterator's static element type.
        for dataset in (dataset_3, dataset_4):
            self.assertEqual(dataset_ops.get_legacy_output_types(dataset),
                             dataset_ops.get_legacy_output_types(iterator))
        iterator_shape = dataset_ops.get_legacy_output_shapes(iterator)
        self.assertEqual([None], iterator_shape.as_list())

        # (initializer, expected single element) in the order they are run:
        # first dataset, a different dataset, then the first dataset again.
        init_cycles = (
            (dataset_3_init_op, values_3),
            (dataset_4_init_op, values_4),
            (dataset_3_init_op, values_3),
        )

        with self.cached_session() as sess:
            # The iterator is initially uninitialized.
            with self.assertRaises(errors.FailedPreconditionError):
                sess.run(get_next)

            for init_op, expected in init_cycles:
                sess.run(init_op)
                self.assertAllEqual(expected, sess.run(get_next))
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(get_next)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testReinitializableIteratorWithFunctions(self):
        """Checks re-initialization with datasets built from a Python generator.

        The same iterator is initialized twice, each time from a fresh
        `from_generator` dataset, and must yield 0..9 both times.
        """
        def g():
            value = 0
            while value < 10:
                yield value
                value += 1

        iterator = iterator_ops.Iterator.from_structure(dtypes.int64, [])
        next_element = iterator.get_next()

        with self.cached_session() as sess:
            # Two passes, each with its own newly created dataset.
            for _ in range(2):
                dataset = dataset_ops.Dataset.from_generator(
                    g, output_types=dtypes.int64)
                sess.run(iterator.make_initializer(dataset))
                for expected in range(10):
                    self.assertEqual(expected, sess.run(next_element))
                with self.assertRaises(errors.OutOfRangeError):
                    sess.run(next_element)

    @combinations.generate(test_base.default_test_combinations())
    def testReinitializableIteratorStaticErrors(self):
        """Checks the static validation done by `from_structure` and `make_initializer`."""
        element_types = (dtypes.int64, dtypes.float64)

        # Non-matching structure for types and shapes.
        with self.assertRaises(TypeError):
            iterator_ops.Iterator.from_structure(element_types, [None])

        # Test validation of dataset argument.
        iterator = iterator_ops.Iterator.from_structure(element_types)

        # Incompatible structure: each component is wrapped in an extra tuple.
        with self.assertRaises(ValueError):
            ints = constant_op.constant([1, 2, 3], dtype=dtypes.int64)
            floats = constant_op.constant([4., 5., 6., 7.],
                                          dtype=dtypes.float64)
            nested_dataset = dataset_ops.Dataset.from_tensors(
                ((ints, ), (floats, )))
            iterator.make_initializer(nested_dataset)

        # Incompatible types: int32/float32 instead of int64/float64.
        with self.assertRaises(TypeError):
            ints = constant_op.constant([1, 2, 3], dtype=dtypes.int32)
            floats = constant_op.constant([4., 5., 6., 7.],
                                          dtype=dtypes.float32)
            narrow_dataset = dataset_ops.Dataset.from_tensors((ints, floats))
            iterator.make_initializer(narrow_dataset)

        # Incompatible shapes: the second component is declared scalar but
        # the dataset supplies a vector.
        iterator = iterator_ops.Iterator.from_structure(
            element_types, ([None], []))
        with self.assertRaises(TypeError):
            ints = constant_op.constant([1, 2, 3], dtype=dtypes.int64)
            floats = constant_op.constant([4., 5., 6., 7.],
                                          dtype=dtypes.float64)
            vector_dataset = dataset_ops.Dataset.from_tensors((ints, floats))
            iterator.make_initializer(vector_dataset)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testIteratorStringHandle(self):
        """Checks that a feedable iterator follows whichever handle is fed.

        Two one-shot iterators are consumed alternately through a single
        `from_string_handle` iterator by switching the fed handle per step.
        """
        dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
        dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

        iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
        iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

        handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
        feedable_iterator = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
            dataset_ops.get_legacy_output_shapes(dataset_3))
        next_element = feedable_iterator.get_next()

        compatible = structure.are_compatible(
            dataset_ops.get_structure(dataset_3),
            dataset_ops.get_structure(feedable_iterator))
        self.assertTrue(compatible)

        with self.cached_session() as sess:
            iterator_3_handle = sess.run(iterator_3.string_handle())
            iterator_4_handle = sess.run(iterator_4.string_handle())

            def fetch_via(handle):
                return sess.run(next_element,
                                feed_dict={handle_placeholder: handle})

            # Alternate between the two underlying iterators; each keeps its
            # own position.
            interleaved = (
                (10, iterator_4_handle),
                (1, iterator_3_handle),
                (20, iterator_4_handle),
                (2, iterator_3_handle),
                (30, iterator_4_handle),
                (3, iterator_3_handle),
                (40, iterator_4_handle),
            )
            for expected, handle in interleaved:
                self.assertEqual(expected, fetch_via(handle))

            # Both underlying iterators are now exhausted.
            for handle in (iterator_3_handle, iterator_4_handle):
                with self.assertRaises(errors.OutOfRangeError):
                    fetch_via(handle)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testIteratorStringHandleFuture(self):
        """Repeats the string-handle interleaving check under a forward-compat horizon.

        Everything runs inside
        `forward_compat.forward_compatibility_horizon(2018, 8, 4)`.
        """
        with forward_compat.forward_compatibility_horizon(2018, 8, 4):
            dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            dataset_4 = dataset_ops.Dataset.from_tensor_slices(
                [10, 20, 30, 40])

            iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
            iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

            handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
            feedable_iterator = iterator_ops.Iterator.from_string_handle(
                handle_placeholder,
                dataset_ops.get_legacy_output_types(dataset_3),
                dataset_ops.get_legacy_output_shapes(dataset_3))
            next_element = feedable_iterator.get_next()

            compatible = structure.are_compatible(
                dataset_ops.get_structure(dataset_3),
                dataset_ops.get_structure(feedable_iterator))
            self.assertTrue(compatible)

            with self.cached_session() as sess:
                iterator_3_handle = sess.run(iterator_3.string_handle())
                iterator_4_handle = sess.run(iterator_4.string_handle())

                def fetch_via(handle):
                    return sess.run(next_element,
                                    feed_dict={handle_placeholder: handle})

                # Alternate between the two underlying iterators; each keeps
                # its own position.
                interleaved = (
                    (10, iterator_4_handle),
                    (1, iterator_3_handle),
                    (20, iterator_4_handle),
                    (2, iterator_3_handle),
                    (30, iterator_4_handle),
                    (3, iterator_3_handle),
                    (40, iterator_4_handle),
                )
                for expected, handle in interleaved:
                    self.assertEqual(expected, fetch_via(handle))

                # Both underlying iterators are now exhausted.
                for handle in (iterator_3_handle, iterator_4_handle):
                    with self.assertRaises(errors.OutOfRangeError):
                        fetch_via(handle)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testIteratorStringHandleReuseTensorObject(self):
        """Checks that `string_handle()` reuses its tensor unless a name is given.

        Calling `string_handle()` with no arguments must return the same
        object on every call and add no ops to the default graph, for each of
        the three iterator kinds built here. Passing an explicit `name` must
        create a new op on every call.
        """
        dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
        one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
        initializable_iterator = dataset_ops.make_initializable_iterator(
            dataset)
        structure_iterator = iterator_ops.Iterator.from_structure(
            dataset_ops.get_legacy_output_types(dataset))

        # Snapshot the op count after all iterators exist, so that only ops
        # added by the string_handle() calls below would change it.
        created_ops = len(ops.get_default_graph().get_operations())

        self.assertIs(one_shot_iterator.string_handle(),
                      one_shot_iterator.string_handle())
        self.assertIs(initializable_iterator.string_handle(),
                      initializable_iterator.string_handle())
        self.assertIs(structure_iterator.string_handle(),
                      structure_iterator.string_handle())

        # Assert that getting the (default) string handle creates no ops.
        self.assertEqual(created_ops,
                         len(ops.get_default_graph().get_operations()))

        # Specifying an explicit name will create a new op.
        handle_with_name = one_shot_iterator.string_handle(name="foo")
        self.assertEqual("foo", handle_with_name.op.name)
        self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name)

        # Requesting the same name again yields yet another op; the expected
        # name shows the graph uniquifying it to "foo_1".
        handle_with_same_name = one_shot_iterator.string_handle(name="foo")
        self.assertEqual("foo_1", handle_with_same_name.op.name)
        self.assertIsNot(handle_with_name, handle_with_same_name)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testIteratorStringHandleError(self):
        """Checks run-time validation of handles fed to a feedable iterator.

        Three feedable iterators share one handle placeholder but declare
        different element shapes: scalar, vector and unspecified. Feeding the
        handle of an int32 scalar dataset must work for the scalar and
        unspecified declarations. Feeding a handle whose elements do not
        match the declared vector-of-int32 structure must raise
        `InvalidArgumentError`.
        """
        dataset_int_scalar = dataset_ops.Dataset.from_tensor_slices(
            [1, 2, 3]).repeat()
        dataset_float_vector = dataset_ops.Dataset.from_tensors(
            [1.0, 2.0, 3.0])

        handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])

        feedable_int_scalar = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dtypes.int32, [])
        feedable_int_vector = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dtypes.int32, [None])
        feedable_int_any = iterator_ops.Iterator.from_string_handle(
            handle_placeholder, dtypes.int32)

        with self.cached_session() as sess:
            handle_int_scalar = sess.run(
                dataset_ops.make_one_shot_iterator(
                    dataset_int_scalar).string_handle())
            handle_float_vector = sess.run(
                dataset_ops.make_one_shot_iterator(
                    dataset_float_vector).string_handle())

            self.assertEqual(
                1,
                sess.run(feedable_int_scalar.get_next(),
                         feed_dict={handle_placeholder: handle_int_scalar}))

            # Same underlying iterator, so this continues with the next
            # element rather than restarting.
            self.assertEqual(
                2,
                sess.run(feedable_int_any.get_next(),
                         feed_dict={handle_placeholder: handle_int_scalar}))

            # Scalar elements fed to an iterator declared as vector-shaped.
            # The result is not printed: the call is expected to raise, so a
            # print wrapper could only ever emit noise on a failing run.
            with self.assertRaises(errors.InvalidArgumentError):
                sess.run(feedable_int_vector.get_next(),
                         feed_dict={handle_placeholder: handle_int_scalar})

            # Float elements fed to an iterator declared as int32.
            with self.assertRaises(errors.InvalidArgumentError):
                sess.run(feedable_int_vector.get_next(),
                         feed_dict={handle_placeholder: handle_float_vector})

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
        """Checks `remote_call` access to an iterator on another CPU device.

        The iterator is created on cpu:1 and its string handle is passed to a
        function invoked through `remote_call`. The call works when targeted
        at cpu:1 and fails when targeted at cpu:2.
        """
        worker_config = config_pb2.ConfigProto()
        worker_config.device_count["CPU"] = 3

        caller_device = "/job:localhost/replica:0/task:0/cpu:0"
        iterator_device = "/job:localhost/replica:0/task:0/cpu:1"
        other_device = "/job:localhost/replica:0/task:0/cpu:2"

        with ops.device(iterator_device):
            dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
            iterator_3_handle = iterator_3.string_handle()

        @function.Defun(dtypes.string)
        def _remote_fn(h):
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                h, dataset_ops.get_legacy_output_types(dataset_3),
                dataset_ops.get_legacy_output_shapes(dataset_3))
            return remote_iterator.get_next()

        with ops.device(caller_device):
            target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
            remote_op = functional_ops.remote_call(args=[iterator_3_handle],
                                                   Tout=[dtypes.int32],
                                                   f=_remote_fn,
                                                   target=target_placeholder)

        with self.session(config=worker_config) as sess:

            def run_on(device):
                return sess.run(remote_op,
                                feed_dict={target_placeholder: device})

            elem = run_on(iterator_device)
            self.assertEqual(elem, [1])
            # Fails when target is cpu:2 where the resource is not located.
            with self.assertRaises(errors.InvalidArgumentError):
                run_on(other_device)
            # The failed call did not consume an element.
            elem = run_on(iterator_device)
            self.assertEqual(elem, [2])
            elem = run_on(iterator_device)
            self.assertEqual(elem, [3])
            with self.assertRaises(errors.OutOfRangeError):
                run_on(iterator_device)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testRemoteIteratorUsingRemoteCallOpMultiWorkers(self):
        """Checks `remote_call` access to iterators living on two worker tasks.

        Three local servers form a cluster: two "worker" tasks and one
        "client" task. Each worker owns a one-shot iterator that yields its
        own device string. A client-side dataset maps over
        (device, iterator handle) pairs and fetches one element from each
        worker through `remote_call`.
        """
        s1 = server_lib.Server.create_local_server()
        s2 = server_lib.Server.create_local_server()
        s3 = server_lib.Server.create_local_server()

        # Build a ClusterDef by hand; each task address is the server target
        # with its "grpc://" prefix sliced off.
        cluster_def = cluster_pb2.ClusterDef()
        workers = cluster_def.job.add()
        workers.name = "worker"
        workers.tasks[0] = s1.target[len("grpc://"):]
        workers.tasks[1] = s2.target[len("grpc://"):]
        client = cluster_def.job.add()
        client.name = "client"
        client.tasks[0] = s3.target[len("grpc://"):]
        config = config_pb2.ConfigProto(cluster_def=cluster_def)

        worker_devices = [
            "/job:worker/replica:0/task:%d/cpu:0" % i for i in range(2)
        ]
        # One iterator per worker device; its single element is the device
        # string itself, which is what the final assertions compare against.
        itr_handles = []
        for device in worker_devices:
            with ops.device(device):
                src = dataset_ops.Dataset.from_tensor_slices([device])
                itr = dataset_ops.make_one_shot_iterator(src)
                itr_handles.append(itr.string_handle())

        targets = dataset_ops.Dataset.from_tensor_slices(worker_devices)
        handles = dataset_ops.Dataset.from_tensor_slices(itr_handles)

        # NOTE(review): `loading_func` reads `itr` from the enclosing scope,
        # and `itr` is rebound further down to the client-side iterator.
        # Which object it sees depends on when the Defun body is traced --
        # presumably while `.map(map_fn)` runs, i.e. before the rebinding.
        # Confirm before reordering anything here.
        @function.Defun(dtypes.string)
        def loading_func(h):
            remote_itr = iterator_ops.Iterator.from_string_handle(
                h, dataset_ops.get_legacy_output_types(itr),
                dataset_ops.get_legacy_output_shapes(itr))
            return remote_itr.get_next()

        def map_fn(target, handle):
            # Run `loading_func` on the device named by `target`.
            return functional_ops.remote_call(args=[handle],
                                              Tout=[dtypes.string],
                                              f=loading_func,
                                              target=target)

        with ops.device("/job:client"):
            client_dataset = dataset_ops.Dataset.zip(
                (targets, handles)).map(map_fn)
            itr = dataset_ops.make_initializable_iterator(client_dataset)
            n = itr.get_next()

        with session.Session(s3.target, config=config) as sess:
            sess.run(itr.initializer)
            expected_values = worker_devices
            for expected in expected_values:
                # Each element is compared as a 1-tuple of bytes.
                self.assertEqual((compat.as_bytes(expected), ), sess.run(n))

            with self.assertRaises(errors.OutOfRangeError):
                sess.run(n)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
        """Checks `remote_call` from a GPU to an iterator that lives on the CPU.

        The string handle is decoded to uint8 on the GPU side, passed through
        `remote_call`, and re-encoded to a string inside the remote function
        with a `py_func`. Skipped when no GPU is available.
        """
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        cpu_device = "/job:localhost/replica:0/task:0/cpu:0"
        gpu_device = "/job:localhost/replica:0/task:0/device:GPU:0"

        with ops.device(cpu_device):
            dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
            iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
            iterator_3_handle = iterator_3.string_handle()

        def _encode_raw(byte_array):
            # Turn the uint8 values back into a bytes object.
            raw = bytearray(byte_array)
            return bytes(raw)

        @function.Defun(dtypes.uint8)
        def _remote_fn(h):
            handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
            remote_iterator = iterator_ops.Iterator.from_string_handle(
                handle, dataset_ops.get_legacy_output_types(dataset_3),
                dataset_ops.get_legacy_output_shapes(dataset_3))
            return remote_iterator.get_next()

        with ops.device(gpu_device):
            target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
            iterator_3_handle_uint8 = parsing_ops.decode_raw(
                input_bytes=iterator_3_handle, out_type=dtypes.uint8)
            remote_op = functional_ops.remote_call(
                args=[iterator_3_handle_uint8],
                Tout=[dtypes.int32],
                f=_remote_fn,
                target=target_placeholder)

        with self.cached_session() as sess:

            def run_on_cpu():
                return sess.run(remote_op,
                                feed_dict={target_placeholder: cpu_device})

            # Three elements in order, then the iterator is exhausted.
            for expected in ([1], [2], [3]):
                elem = run_on_cpu()
                self.assertEqual(elem, expected)
            with self.assertRaises(errors.OutOfRangeError):
                run_on_cpu()

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2], mode=["graph"]))
    def testRepeatedGetNextWarning(self):
        """Checks that calling `get_next()` too many times emits warnings.

        `get_next()` is called 100 times on one iterator. Every call beyond
        `iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD` must produce one
        warning containing `iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE`.
        """
        iterator = dataset_ops.make_one_shot_iterator(
            dataset_ops.Dataset.range(10))
        with warnings.catch_warnings(record=True) as w:
            # Install the "always" filter inside the context manager so that
            # it is undone on exit. Setting it outside would leave the
            # process-wide warning filters modified for every later test.
            warnings.simplefilter("always")
            for _ in range(100):
                iterator.get_next()
        self.assertEqual(100 - iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD,
                         len(w))
        for warning in w:
            self.assertIn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE,
                          str(warning.message))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                expected_element_structure=tensor_spec.TensorSpec(
                    [], dtypes.float32),
                expected_output_classes=ops.Tensor,
                expected_output_types=dtypes.float32,
                expected_output_shapes=[[]])))
    def testTensorIteratorStructure(self, expected_element_structure,
                                    expected_output_classes,
                                    expected_output_types,
                                    expected_output_shapes):
        """Checks the structure an iterator reports for a scalar float tensor.

        Args:
          expected_element_structure: Structure the iterator's element
            structure must be compatible with.
          expected_output_classes: Expected legacy output classes.
          expected_output_types: Expected legacy output types.
          expected_output_shapes: Expected legacy output shapes.
        """
        def make_value():
            return constant_op.constant(37.0)

        source = dataset_ops.Dataset.from_tensors(make_value())
        iterator = dataset_ops.make_one_shot_iterator(source)

        actual_structure = dataset_ops.get_structure(iterator)
        self.assertTrue(
            structure.are_compatible(actual_structure,
                                     expected_element_structure))

        # Legacy accessors, checked in the order classes, types, shapes.
        legacy_checks = (
            (expected_output_classes, dataset_ops.get_legacy_output_classes),
            (expected_output_types, dataset_ops.get_legacy_output_types),
            (expected_output_shapes, dataset_ops.get_legacy_output_shapes),
        )
        for expected, accessor in legacy_checks:
            self.assertEqual(expected, accessor(iterator))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(
                expected_element_structure=sparse_tensor.SparseTensorSpec(
                    [1], dtypes.int32),
                expected_output_classes=sparse_tensor.SparseTensor,
                expected_output_types=dtypes.int32,
                expected_output_shapes=[[1]])))
    def testSparseTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        """Checks the structure an iterator reports for a sparse int32 tensor.

        Args:
          expected_element_structure: Structure the iterator's element
            structure must be compatible with.
          expected_output_classes: Expected legacy output classes.
          expected_output_types: Expected legacy output types.
          expected_output_shapes: Expected legacy output shapes.
        """
        def make_value():
            sparse_values = constant_op.constant([0], dtype=dtypes.int32)
            return sparse_tensor.SparseTensor(indices=[[0]],
                                              values=sparse_values,
                                              dense_shape=[1])

        source = dataset_ops.Dataset.from_tensors(make_value())
        iterator = dataset_ops.make_one_shot_iterator(source)

        actual_structure = dataset_ops.get_structure(iterator)
        self.assertTrue(
            structure.are_compatible(actual_structure,
                                     expected_element_structure))

        # Legacy accessors, checked in the order classes, types, shapes.
        legacy_checks = (
            (expected_output_classes, dataset_ops.get_legacy_output_classes),
            (expected_output_types, dataset_ops.get_legacy_output_types),
            (expected_output_shapes, dataset_ops.get_legacy_output_shapes),
        )
        for expected, accessor in legacy_checks:
            self.assertEqual(expected, accessor(iterator))

    @combinations.generate(
        combinations.times(
            test_base.default_test_combinations(),
            combinations.combine(expected_element_structure={
                "a":
                tensor_spec.TensorSpec([], dtypes.float32),
                "b": (tensor_spec.TensorSpec([1], dtypes.string),
                      tensor_spec.TensorSpec([], dtypes.string))
            },
                                 expected_output_classes={
                                     "a": ops.Tensor,
                                     "b": (ops.Tensor, ops.Tensor)
                                 },
                                 expected_output_types={
                                     "a": dtypes.float32,
                                     "b": (dtypes.string, dtypes.string)
                                 },
                                 expected_output_shapes={
                                     "a": [],
                                     "b": ([1], [])
                                 })))
    def testNestedTensorIteratorStructure(self, expected_element_structure,
                                          expected_output_classes,
                                          expected_output_types,
                                          expected_output_shapes):
        """Checks the structure an iterator reports for a nested dict/tuple value.

        Args:
          expected_element_structure: Structure the iterator's element
            structure must be compatible with.
          expected_output_classes: Expected legacy output classes.
          expected_output_types: Expected legacy output types.
          expected_output_shapes: Expected legacy output shapes.
        """
        def make_value():
            scalar_float = constant_op.constant(37.0)
            string_pair = (constant_op.constant(["Foo"]),
                           constant_op.constant("Bar"))
            return {"a": scalar_float, "b": string_pair}

        source = dataset_ops.Dataset.from_tensors(make_value())
        iterator = dataset_ops.make_one_shot_iterator(source)

        actual_structure = dataset_ops.get_structure(iterator)
        self.assertTrue(
            structure.are_compatible(actual_structure,
                                     expected_element_structure))

        # Legacy accessors, checked in the order classes, types, shapes.
        legacy_checks = (
            (expected_output_classes, dataset_ops.get_legacy_output_classes),
            (expected_output_types, dataset_ops.get_legacy_output_types),
            (expected_output_shapes, dataset_ops.get_legacy_output_shapes),
        )
        for expected, accessor in legacy_checks:
            self.assertEqual(expected, accessor(iterator))

    @combinations.generate(test_base.default_test_combinations())
    def testIteratorGetNextName(self):
        """Checks that `get_next(name=...)` names the resulting op as requested."""
        requested_name = "overridden_name"
        with ops.Graph().as_default():
            source = dataset_ops.Dataset.from_tensors(37.0)
            iterator = dataset_ops.make_one_shot_iterator(source)
            next_element = iterator.get_next(name=requested_name)
            self.assertEqual(requested_name, next_element.op.name)

    @combinations.generate(
        combinations.combine(tf_api_version=[1, 2],
                             mode="eager",
                             execution_mode=[context.ASYNC, context.SYNC]))
    def testIteratorEagerIteration(self, execution_mode):
        """Checks Python `for` iteration over a dataset in eager mode.

        Args:
          execution_mode: Eager execution mode to run under, supplied by the
            test combinations.
        """
        with context.eager_mode(), context.execution_mode(execution_mode):
            dataset = dataset_ops.Dataset.range(10)
            # Elements must come out as 0, 1, 2, ... in order.
            for expected, element in enumerate(iter(dataset)):
                self.assertEqual(expected, element.numpy())

    @combinations.generate(combinations.combine(tf_api_version=2,
                                                mode="eager"))
    def testIteratorV2Function(self):
        """Checks that a dataset iterator can be created and used in a function.

        The `def_function.function`-decorated `fn` iterates a range dataset
        with `iter`/`next` and pushes each element onto a FIFO queue. The
        queue is then drained outside the function to check the elements.
        """

        # The queue carries the elements out of the traced function.
        queue = data_flow_ops.FIFOQueue(10, dtypes.int64)

        @def_function.function
        def fn():
            dataset = dataset_ops.Dataset.range(10)
            iterator = iter(dataset)
            for _ in range(10):
                queue.enqueue(next(iterator))

        fn()

        # Elements must have been enqueued in order 0..9.
        for i in range(10):
            self.assertEqual(queue.dequeue().numpy(), i)

    @combinations.generate(combinations.combine(tf_api_version=2,
                                                mode="eager"))
    def testIteratorV2FunctionError(self):
        """Checks iterator cleanup when the enclosing function raises.

        `finalize_fn` enqueues onto a queue that already holds one element,
        and the test ends by asserting the queue size is 2.
        """
        # In this test we verify that a function that raises an error ends up
        # properly deallocating the iterator resource.

        # Seed the queue with one element; a second one is added only by
        # `finalize_fn` below.
        queue = data_flow_ops.FIFOQueue(10, dtypes.int64)
        queue.enqueue(0)

        def init_fn(n):
            return n

        def next_fn(_):
            # Takes the next element of a `range(0)` dataset -- presumably
            # empty, which would be the source of the OutOfRangeError
            # asserted below.
            ds = dataset_ops.Dataset.range(0)
            return next(iter(ds))

        def finalize_fn(n):
            # Side effect used to observe that finalization happened.
            queue.enqueue(0)
            return n

        @def_function.function
        def fn():
            dataset = dataset_ops._GeneratorDataset(1, init_fn, next_fn,
                                                    finalize_fn)
            iterator = iter(dataset)
            next(iterator)

        with self.assertRaises(errors.OutOfRangeError):
            fn()

        # One seed element plus one enqueue from `finalize_fn`.
        self.assertEqual(queue.size().numpy(), 2)

    @combinations.generate(combinations.combine(tf_api_version=2,
                                                mode="eager"))
    def testLimitedRetracing(self):
        """Checks that passing different iterators does not retrace the function.

        The traced function sums the elements of the iterator it is given. It
        is called repeatedly with fresh iterators over two datasets, and the
        Python body must have run exactly once.
        """
        # Single-element list so the nested function can mutate the counter.
        trace_count = [0]

        @def_function.function
        def f(iterator):
            trace_count[0] += 1
            total = np.int64(0)
            for value in iterator:
                total += value
            return total

        # (dataset, expected sum of its elements)
        cases = (
            (dataset_ops.Dataset.range(5), 10),
            (dataset_ops.Dataset.range(10), 45),
        )

        for _ in range(10):
            for dataset, expected_sum in cases:
                self.assertEqual(self.evaluate(f(iter(dataset))),
                                 expected_sum)
            self.assertEqual(trace_count[0], 1)
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testFromValue(self):
    """An Optional built from a scalar constant reports and returns it."""
    scalar = constant_op.constant(37.0)
    opt = optional_ops.Optional.from_value(scalar)
    has_value = self.evaluate(opt.has_value())
    self.assertTrue(has_value)
    stored = self.evaluate(opt.get_value())
    self.assertEqual(37.0, stored)

  def testFromStructuredValue(self):
    """An Optional built from a nested dict/tuple returns the same nest."""
    nested_value = {
        "a": constant_op.constant(37.0),
        "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
    }
    # String tensors evaluate to bytes, hence the b"..." literals.
    expected = {"a": 37.0, "b": ([b"Foo"], b"Bar")}

    opt = optional_ops.Optional.from_value(nested_value)
    self.assertTrue(self.evaluate(opt.has_value()))
    self.assertEqual(expected, self.evaluate(opt.get_value()))

  def testFromSparseTensor(self):
    """Sparse components of an Optional come back with every field intact."""
    vector_st = sparse_tensor.SparseTensorValue(
        indices=np.array([[0]]),
        values=np.array([0], dtype=np.int64),
        dense_shape=np.array([1]))
    matrix_st = sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 1]]),
        values=np.array([-1., 1.], dtype=np.float32),
        dense_shape=np.array([2, 2]))

    opt = optional_ops.Optional.from_value((vector_st, matrix_st))
    self.assertTrue(self.evaluate(opt.has_value()))

    got_vector, got_matrix = opt.get_value()
    for want, got in ((vector_st, got_vector), (matrix_st, got_matrix)):
      # Compare the three sparse components in the same order each time.
      for field in ("indices", "values", "dense_shape"):
        self.assertAllEqual(getattr(want, field),
                            self.evaluate(getattr(got, field)))

  def testFromNone(self):
    """An empty Optional keeps its structure but has no value to return."""
    value_structure = tensor_spec.TensorSpec([], dtypes.float32)
    opt = optional_ops.Optional.none_from_structure(value_structure)

    self.assertTrue(opt.value_structure.is_compatible_with(value_structure))
    # Neither a different shape nor a different dtype is compatible.
    for mismatched in (tensor_spec.TensorSpec([1], dtypes.float32),
                       tensor_spec.TensorSpec([], dtypes.int32)):
      self.assertFalse(opt.value_structure.is_compatible_with(mismatched))

    self.assertFalse(self.evaluate(opt.has_value()))
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(opt.get_value())

  def testAddN(self):
    """add_n over Optional variants sums held values, or stays empty."""
    gpu_devices = ["/gpu:0"] if test_util.is_gpu_available() else []
    for device in ["/cpu:0"] + gpu_devices:
      with ops.device(device):
        # Both operands hold a value: the components are added pairwise.
        lhs = optional_ops.Optional.from_value((1.0, 2.0))
        rhs = optional_ops.Optional.from_value((3.0, 4.0))

        summed_variant = math_ops.add_n(
            [lhs._variant_tensor, rhs._variant_tensor])
        summed = optional_ops._OptionalImpl(summed_variant,
                                            lhs.value_structure)
        self.assertAllEqual(self.evaluate(summed.get_value()), (4.0, 6.0))

        # Both operands are empty: the result has no value either.
        empty_lhs = optional_ops.Optional.none_from_structure(
            lhs.value_structure)
        empty_rhs = optional_ops.Optional.none_from_structure(
            rhs.value_structure)
        empty_variant = math_ops.add_n(
            [empty_lhs._variant_tensor, empty_rhs._variant_tensor])
        empty_sum = optional_ops._OptionalImpl(empty_variant,
                                               empty_lhs.value_structure)
        self.assertFalse(self.evaluate(empty_sum.has_value()))

  def testNestedAddN(self):
    """add_n also sums Optionals that are nested inside other Optionals."""
    gpu_devices = ["/gpu:0"] if test_util.is_gpu_available() else []
    for device in ["/cpu:0"] + gpu_devices:
      with ops.device(device):
        inner_a = optional_ops.Optional.from_value([1, 2.0])
        inner_b = optional_ops.Optional.from_value([3, 4.0])
        # Each outer Optional carries a float plus an inner Optional variant.
        outer_a = optional_ops.Optional.from_value(
            (5.0, inner_a._variant_tensor))
        outer_b = optional_ops.Optional.from_value(
            (6.0, inner_b._variant_tensor))

        outer_sum_variant = math_ops.add_n(
            [outer_a._variant_tensor, outer_b._variant_tensor])
        outer_sum = optional_ops._OptionalImpl(outer_sum_variant,
                                               outer_a.value_structure)
        self.assertEqual(self.evaluate(outer_sum.get_value()[0]), 11.0)

        # The second component is itself an Optional variant; unwrap it with
        # the structure of the inner operands.
        inner_sum = optional_ops._OptionalImpl(outer_sum.get_value()[1],
                                               inner_a.value_structure)
        self.assertAllEqual(inner_sum.get_value(), [4, 6.0])

  def testZerosLike(self):
    """zeros_like zeroes a held value and keeps an empty Optional empty."""
    gpu_devices = ["/gpu:0"] if test_util.is_gpu_available() else []
    for device in ["/cpu:0"] + gpu_devices:
      with ops.device(device):
        # Optional holding a value: every component becomes zero.
        full = optional_ops.Optional.from_value((1.0, 2.0))
        zeroed_full = optional_ops._OptionalImpl(
            array_ops.zeros_like(full._variant_tensor),
            full.value_structure)
        self.assertAllEqual(self.evaluate(zeroed_full.get_value()),
                            (0.0, 0.0))

        # Optional without a value: the result still reports no value.
        empty = optional_ops.Optional.none_from_structure(
            full.value_structure)
        zeroed_empty = optional_ops._OptionalImpl(
            array_ops.zeros_like(empty._variant_tensor),
            empty.value_structure)
        self.assertFalse(self.evaluate(zeroed_empty.has_value()))

  def testNestedZerosLike(self):
    """zeros_like reaches the value of an Optional nested in an Optional."""
    gpu_devices = ["/gpu:0"] if test_util.is_gpu_available() else []
    for device in ["/cpu:0"] + gpu_devices:
      with ops.device(device):
        inner = optional_ops.Optional.from_value(1.0)
        outer = optional_ops.Optional.from_value(inner._variant_tensor)

        zeroed_outer = optional_ops._OptionalImpl(
            array_ops.zeros_like(outer._variant_tensor),
            outer.value_structure)
        # The outer value is the inner variant; unwrap it to read the float.
        zeroed_inner = optional_ops._OptionalImpl(zeroed_outer.get_value(),
                                                  inner.value_structure)
        self.assertEqual(self.evaluate(zeroed_inner.get_value()), 0.0)

  def testCopyToGPU(self):
    """Optionals created on the CPU stay readable once moved to the GPU."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/cpu:0"):
      cpu_full = optional_ops.Optional.from_value(
          (constant_op.constant(37.0), constant_op.constant("Foo"),
           constant_op.constant(42)))
      cpu_empty = optional_ops.Optional.none_from_structure(
          tensor_spec.TensorSpec([], dtypes.float32))

    with ops.device("/gpu:0"):
      # Re-wrap the CPU variants through identity() inside the GPU scope.
      gpu_full = optional_ops._OptionalImpl(
          array_ops.identity(cpu_full._variant_tensor),
          cpu_full.value_structure)
      gpu_empty = optional_ops._OptionalImpl(
          array_ops.identity(cpu_empty._variant_tensor),
          cpu_empty.value_structure)

      full_has_value = gpu_full.has_value()
      full_values = gpu_full.get_value()

      empty_has_value = gpu_empty.has_value()

    self.assertTrue(self.evaluate(full_has_value))
    self.assertEqual((37.0, b"Foo", 42), self.evaluate(full_values))
    self.assertFalse(self.evaluate(empty_has_value))

  def testNestedCopyToGPU(self):
    """Optionals nested in an Optional stay readable once moved to the GPU."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/cpu:0"):
      cpu_full = optional_ops.Optional.from_value(
          (constant_op.constant(37.0), constant_op.constant("Foo"),
           constant_op.constant(42)))
      cpu_empty = optional_ops.Optional.none_from_structure(
          tensor_spec.TensorSpec([], dtypes.float32))
      # Outer Optional: a full inner variant, an empty inner variant, a float.
      cpu_nested = optional_ops.Optional.from_value(
          (cpu_full._variant_tensor, cpu_empty._variant_tensor, 1.0))

    with ops.device("/gpu:0"):
      gpu_nested = optional_ops._OptionalImpl(
          array_ops.identity(cpu_nested._variant_tensor),
          cpu_nested.value_structure)

      nested_has_value = gpu_nested.has_value()
      nested_values = gpu_nested.get_value()

    self.assertTrue(self.evaluate(nested_has_value))

    # Unwrap the two inner variants with the structures they were built with.
    inner_full = optional_ops._OptionalImpl(nested_values[0],
                                            cpu_full.value_structure)
    inner_empty = optional_ops._OptionalImpl(nested_values[1],
                                             cpu_empty.value_structure)

    self.assertEqual((37.0, b"Foo", 42),
                     self.evaluate(inner_full.get_value()))
    self.assertFalse(self.evaluate(inner_empty.has_value()))
    self.assertEqual(1.0, self.evaluate(nested_values[2]))

  def _assertElementValueEqual(self, expected, actual):
    """Recursively asserts that `actual` matches `expected`.

    Dicts are compared by key set and then value by value, sparse tensor
    values are compared field by field, and everything else is handed to
    `assertAllEqual`.
    """
    if isinstance(expected, dict):
      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
      for key, expected_item in expected.items():
        self._assertElementValueEqual(expected_item, actual[key])
      return

    if isinstance(expected, sparse_tensor.SparseTensorValue):
      for field in ("indices", "values", "dense_shape"):
        self.assertAllEqual(getattr(expected, field), getattr(actual, field))
      return

    self.assertAllEqual(expected, actual)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       tensor_spec.TensorSpec([], dtypes.float32)),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0, 1]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[10, 10]),
       sparse_tensor.SparseTensorSpec([10, 10], dtypes.int32)),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
      }, {
          "a":
              tensor_spec.TensorSpec([], dtypes.float32),
          "b": (
              tensor_spec.TensorSpec([1], dtypes.string),
              tensor_spec.TensorSpec([], dtypes.string),
          )
      }),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalSpec(
           tensor_spec.TensorSpec([], dtypes.float32))),
  )
  def testOptionalSpec(self, tf_value_fn, expected_value_structure):
    """Checks the type spec derived from an Optional and its round-trip.

    Args:
      tf_value_fn: Zero-argument callable that builds the value to wrap in an
        Optional.
      expected_value_structure: Structure the Optional's value structure must
        be compatible with.
    """
    tf_value = tf_value_fn()
    opt = optional_ops.Optional.from_value(tf_value)

    self.assertTrue(
        structure.are_compatible(opt.value_structure, expected_value_structure))

    opt_structure = structure.type_spec_from_value(opt)
    self.assertIsInstance(opt_structure, optional_ops.OptionalSpec)
    self.assertTrue(structure.are_compatible(opt_structure, opt_structure))
    self.assertTrue(
        structure.are_compatible(opt_structure._value_structure,
                                 expected_value_structure))
    # Whatever it wraps, the Optional flattens to one scalar variant tensor.
    self.assertEqual([dtypes.variant],
                     structure.get_flat_tensor_types(opt_structure))
    self.assertEqual([tensor_shape.TensorShape([])],
                     structure.get_flat_tensor_shapes(opt_structure))

    # All OptionalSpec objects are not compatible with a non-optional
    # value.
    non_optional_structure = structure.type_spec_from_value(
        constant_op.constant(42.0))
    self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))

    # Assert that the optional survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_opt = opt_structure._from_tensor_list(
        opt_structure._to_tensor_list(opt))
    if isinstance(tf_value, optional_ops.Optional):
      # The wrapped value is itself an Optional, so unwrap twice.
      self._assertElementValueEqual(
          self.evaluate(tf_value.get_value()),
          self.evaluate(round_trip_opt.get_value().get_value()))
    else:
      self._assertElementValueEqual(
          self.evaluate(tf_value),
          self.evaluate(round_trip_opt.get_value()))

  @parameterized.named_parameters(
      ("Tensor", np.array([1, 2, 3], dtype=np.int32),
       lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
      ("SparseTensor", sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]],
          values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
       False),
      ("Nest", {"a": np.array([1, 2, 3], dtype=np.int32),
                "b": sparse_tensor.SparseTensorValue(
                    indices=[[0, 0], [1, 1]],
                    values=np.array([-1., 1.], dtype=np.float32),
                    dense_shape=[2, 2])},
       lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
                "b": sparse_tensor.SparseTensor(
                    indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
                    dense_shape=[2, 2])}, False),
  )
  def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                    works_on_gpu):
    """Checks `get_next_as_optional` before, during and after iteration.

    The dataset repeats `np_value` three times. The eager branch requests a
    new Optional per element; the graph branch builds the Optional ops once
    and evaluates them repeatedly.

    Args:
      np_value: Element value used to build the dataset and as the expected
        result of each `get_value()`.
      tf_value_fn: Zero-argument callable building a value whose type spec
        the Optional's value structure must be compatible with.
      works_on_gpu: When false, the case is skipped if a GPU is available.
    """
    if not works_on_gpu and test.is_gpu_available():
      self.skipTest("Test case not yet supported on GPU.")
    ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)

    if context.executing_eagerly():
      iterator = dataset_ops.make_one_shot_iterator(ds)
      # For each element of the dataset, assert that the optional evaluates to
      # the expected value.
      for _ in range(3):
        next_elem = iterator_ops.get_next_as_optional(iterator)
        self.assertIsInstance(next_elem, optional_ops.Optional)
        self.assertTrue(structure.are_compatible(
            next_elem.value_structure,
            structure.type_spec_from_value(tf_value_fn())))
        self.assertTrue(next_elem.has_value())
        self._assertElementValueEqual(np_value, next_elem.get_value())
      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
      for _ in range(2):
        next_elem = iterator_ops.get_next_as_optional(iterator)
        self.assertFalse(self.evaluate(next_elem.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(next_elem.get_value())
    else:
      iterator = dataset_ops.make_initializable_iterator(ds)
      next_elem = iterator_ops.get_next_as_optional(iterator)
      self.assertIsInstance(next_elem, optional_ops.Optional)
      self.assertTrue(structure.are_compatible(
          next_elem.value_structure,
          structure.type_spec_from_value(tf_value_fn())))
      # Before initializing the iterator, evaluating the optional fails with
      # a FailedPreconditionError. This is only relevant in graph mode.
      elem_has_value_t = next_elem.has_value()
      elem_value_t = next_elem.get_value()
      with self.assertRaises(errors.FailedPreconditionError):
        self.evaluate(elem_has_value_t)
      with self.assertRaises(errors.FailedPreconditionError):
        self.evaluate(elem_value_t)
      # Now we initialize the iterator.
      self.evaluate(iterator.initializer)
      # For each element of the dataset, assert that the optional evaluates to
      # the expected value.
      for _ in range(3):
        # Both tensors are fetched in a single evaluate call.
        elem_has_value, elem_value = self.evaluate(
            [elem_has_value_t, elem_value_t])
        self.assertTrue(elem_has_value)
        self._assertElementValueEqual(np_value, elem_value)

      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
      for _ in range(2):
        self.assertFalse(self.evaluate(elem_has_value_t))
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(elem_value_t)

  def testFunctionBoundaries(self):
    """An Optional's variant tensor can cross traced-function boundaries."""
    @def_function.function
    def get_optional():
      # TODO(skyewm): support returning Optionals from functions?
      wrapped = optional_ops.Optional.from_value(constant_op.constant(1.0))
      return wrapped._variant_tensor

    # TODO(skyewm): support Optional arguments?
    @def_function.function
    def consume_optional(opt_tensor):
      # Rebuild the Optional around the raw variant using a known structure.
      rebuilt = optional_ops._OptionalImpl(
          opt_tensor, tensor_spec.TensorSpec([], dtypes.float32))
      return rebuilt.get_value()

    unwrapped = consume_optional(get_optional())
    self.assertEqual(self.evaluate(unwrapped), 1.0)

  def testLimitedRetracing(self):
    """Different Optionals with the same structure share a single trace."""
    trace_count = [0]

    @def_function.function
    def f(opt):
      # Python side effect: bumps the counter each time `f` is traced.
      trace_count[0] += 1
      return opt.get_value()

    # Pairs of (optional, value it holds).
    cases = (
        (optional_ops.Optional.from_value(constant_op.constant(37.0)), 37.0),
        (optional_ops.Optional.from_value(constant_op.constant(42.0)), 42.0),
    )

    for _ in range(10):
      for opt, held_value in cases:
        self.assertEqual(self.evaluate(f(opt)), held_value)
      self.assertEqual(trace_count[0], 1)
Exemple #25
0
def _SparseTensorStructure(dtype, shape):
    """Builds a `SparseTensorSpec`; note the (dtype, shape) argument order."""
    return sparse_tensor.SparseTensorSpec(shape=shape, dtype=dtype)
Exemple #26
0
    def __init__(self, input_dataset, features, num_parallel_calls,
                 deterministic):
        """Creates a dataset that parses the elements of `input_dataset`.

        Args:
          input_dataset: Dataset whose element spec must be compatible with
            a rank-1 string `TensorSpec`.
          features: Feature configuration; a `None` dimension is prepended
            before the parse-op parameters are derived from it.
          num_parallel_calls: Forwarded unchanged to the parse op.
          deterministic: `None`, or a truthy/falsy value. Passed to the op
            as the string "default", "true" or "false" respectively.

        Raises:
          TypeError: If `input_dataset` is not a dataset of string vectors.
        """
        self._input_dataset = input_dataset
        expected_input_spec = tensor_spec.TensorSpec([None], dtypes.string)
        if not structure.are_compatible(input_dataset.element_spec,
                                        expected_input_spec):
            raise TypeError(
                "Input dataset should be a dataset of vectors of strings")
        self._num_parallel_calls = num_parallel_calls

        # The op receives the determinism setting as a three-valued string.
        if deterministic is None:
            self._deterministic = "default"
        else:
            self._deterministic = "true" if deterministic else "false"

        # pylint: disable=protected-access
        self._features = parsing_ops._prepend_none_dimension(features)
        # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature
        accepted_feature_types = [
            parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
            parsing_ops.FixedLenFeature,
            parsing_ops.FixedLenSequenceFeature, parsing_ops.RaggedFeature
        ]
        params = parsing_ops._ParseOpParams.from_features(
            self._features, accepted_feature_types)
        # pylint: enable=protected-access

        # Keep the derived parameters around; they are the op's arguments.
        self._sparse_keys = params.sparse_keys
        self._sparse_types = params.sparse_types
        self._ragged_keys = params.ragged_keys
        self._ragged_value_types = params.ragged_value_types
        self._ragged_split_types = params.ragged_split_types
        self._dense_keys = params.dense_keys
        self._dense_defaults = params.dense_defaults_vec
        self._dense_shapes = params.dense_shapes_as_proto
        self._dense_types = params.dense_types

        input_dataset_shape = dataset_ops.get_legacy_output_shapes(
            self._input_dataset)

        # One spec per feature key, filled in the order sparse, dense, ragged.
        # Each spec's shape extends the input dataset's shape.
        element_spec = {}

        sparse_features = zip(params.sparse_keys, params.sparse_types)
        for key, value_type in sparse_features:
            element_spec[key] = sparse_tensor.SparseTensorSpec(
                input_dataset_shape.concatenate([None]), value_type)

        dense_features = zip(params.dense_keys, params.dense_types,
                             params.dense_shapes)
        for key, value_type, dense_shape in dense_features:
            element_spec[key] = tensor_spec.TensorSpec(
                input_dataset_shape.concatenate(dense_shape), value_type)

        ragged_features = zip(params.ragged_keys, params.ragged_value_types,
                              params.ragged_split_types)
        for key, value_type, splits_type in ragged_features:
            element_spec[key] = ragged_tensor.RaggedTensorSpec(
                input_dataset_shape.concatenate([None]), value_type, 1,
                splits_type)

        # Must be set before `self._flat_structure` is read below.
        self._element_spec = element_spec

        variant_tensor = (
            gen_experimental_dataset_ops.parse_example_dataset_v2(
                self._input_dataset._variant_tensor,  # pylint: disable=protected-access
                self._num_parallel_calls,
                self._dense_defaults,
                self._sparse_keys,
                self._dense_keys,
                self._sparse_types,
                self._dense_shapes,
                deterministic=self._deterministic,
                ragged_keys=self._ragged_keys,
                ragged_value_types=self._ragged_value_types,
                ragged_split_types=self._ragged_split_types,
                **self._flat_structure))
        super(_ParseExampleDataset, self).__init__(input_dataset,
                                                   variant_tensor)
Exemple #27
0
 def common_spec(x, y):
   """Returns a spec of x's kind whose shape is common to `x` and `y`.

   The dtype always comes from `x`; for the dense case the name does too.
   """
   merged_shape = defun.common_shape(x.shape, y.shape)
   if not isinstance(x, sparse_tensor.SparseTensorSpec):
     return tensor_spec.TensorSpec(merged_shape, x.dtype, x.name)
   return sparse_tensor.SparseTensorSpec(merged_shape, x.dtype)
Exemple #28
0
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    def testAsSerializedGraph(self):
        """The serialized graph of a range dataset contains its dataset op."""
        dataset = dataset_ops.Dataset.range(10)
        graph = graph_pb2.GraphDef().FromString(
            self.evaluate(dataset._as_serialized_graph()))
        # Match the op by equality. A `!=` comparison would be satisfied by
        # any unrelated node in the graph and so would verify nothing about
        # the dataset itself.
        self.assertTrue(
            any(node.op == "RangeDataset" for node in graph.node))

    def testAsFunctionWithMap(self):
        """A mapped dataset can be rebuilt from its traced variant (eager)."""
        if not context.executing_eagerly():
            self.skipTest("Only works executing eagerly")
        with ops.device("CPU"):
            doubled = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
            make_variant = doubled._trace_variant_creation()

            # Re-wrap the variant with the element spec of the source dataset.
            revived_dataset = dataset_ops._VariantDataset(
                make_variant(), doubled.element_spec)
            self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

    def testAsFunctionWithMapInFlatMap(self):
        """Like testAsFunctionWithMap, with the map nested in a flat_map."""
        if not context.executing_eagerly():
            self.skipTest("Only works executing eagerly")
        with ops.device("CPU"):
            nested = dataset_ops.Dataset.range(5).flat_map(
                lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
            make_variant = nested._trace_variant_creation()

            revived_dataset = dataset_ops._VariantDataset(
                make_variant(), nested.element_spec)
            # The revived dataset must yield what the source dataset yields.
            expected_elements = list(nested)
            self.assertDatasetProduces(revived_dataset, expected_elements)

    @staticmethod
    def make_apply_fn(dataset):
        def apply_fn(dataset):
            def _apply_fn(dataset):
                return dataset.cache()

            return dataset.apply(_apply_fn)

        return apply_fn

    @staticmethod
    def make_gen():
        def gen():
            yield 42

        return gen

    @staticmethod
    def make_interleave_fn(dataset, num_parallel_calls=None):
        def interleave_fn(dataset):
            return dataset.interleave(lambda x: dataset_ops.Dataset.range(0),
                                      cycle_length=2,
                                      num_parallel_calls=num_parallel_calls)

        return interleave_fn

    @parameterized.named_parameters(
        ("FixedLengthRecord",
         lambda: readers.FixedLengthRecordDataset("", 42)),
        ("FromGenerator", lambda: dataset_ops.Dataset.from_generator(
            DatasetTest.make_gen(), dtypes.int32), 1),
        ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
        # This case must build its dataset with from_tensor_slices; using
        # from_tensors would make it a duplicate of the FromTensors case.
        ("FromTensorSlices",
         lambda: dataset_ops.Dataset.from_tensor_slices([42])),
        ("Range", lambda: dataset_ops.Dataset.range(10)),
        ("TextLine", lambda: readers.TextLineDataset("")),
        ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
    )
    def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
        """Checks how many input datasets each source dataset reports.

        Args:
          dataset_fn: Zero-argument callable that builds the dataset.
          num_inputs: Expected length of the dataset's `_inputs()`; zero
            unless the case says otherwise.
        """
        self.assertLen(dataset_fn()._inputs(), num_inputs)

    @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
    def testDatasetComplexSourceInputs(self):
        """A dataset sliced from a sparse tensor reports no input datasets."""
        sparse_input = sparse_tensor.SparseTensor(
            indices=np.array([[0, 0], [1, 0], [2, 0]]),
            values=np.array([0, 0, 0]),
            dense_shape=np.array([3, 1]))
        sparse_dataset = dataset_ops.Dataset.from_sparse_tensor_slices(
            sparse_input)
        self.assertEmpty(sparse_dataset._inputs())

    @parameterized.named_parameters(
        ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
        ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
        ("Filter", lambda x: x.filter(lambda x: True),
         lambda: dataset_ops.Dataset.range(0)),
        ("FlatMap",
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
         lambda: dataset_ops.Dataset.range(0)),
        ("Map", lambda x: x.map(lambda x: x),
         lambda: dataset_ops.Dataset.range(0)),
        ("PaddedBatch", lambda x: x.padded_batch(10, []),
         lambda: dataset_ops.Dataset.range(0)),
        ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
         lambda: dataset_ops.Dataset.range(0)),
        ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
        ("Shuffle", lambda x: x.shuffle(10),
         lambda: dataset_ops.Dataset.range(0)),
        ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
        ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
        ("Window", lambda x: x.window(10),
         lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
        """A unary transformation reports exactly its source as input.

        Args:
          dataset_fn: Callable applying the transformation to a dataset.
          input_dataset_fn: Zero-argument callable building the source.
        """
        source = input_dataset_fn()
        transformed = dataset_fn(source)
        self.assertEqual([source], transformed._inputs())

    def testUnaryTransformationInputsApply(self):
        """A transformation applied via `apply` reports its source as input."""
        source = dataset_ops.Dataset.range(0)
        apply_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
        transformed = apply_fn(source)
        self.assertEqual([source], transformed._inputs())

    @parameterized.named_parameters(
        ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2
                                ], lambda: dataset_ops.Dataset.range(0)),
        ("Interleave", [lambda: dataset_ops.Dataset.range(0), None
                        ], lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputsWithInterleaveFn(self, interleave_fn_args,
                                                      input_dataset_fn):
        """An interleaved dataset reports exactly its source as input.

        Args:
          interleave_fn_args: Positional arguments for `make_interleave_fn`.
          input_dataset_fn: Zero-argument callable building the source.
        """
        source = input_dataset_fn()
        interleave = self.make_interleave_fn(*interleave_fn_args)
        interleaved = interleave(source)
        self.assertEqual([source], interleaved._inputs())

    def testNoWarnings(self):
        """Building an interleave pipeline must not call `warnings.warn`."""
        with test.mock.patch.object(warnings, "warn") as patched_warn:
            interleave = self.make_interleave_fn(
                dataset_ops.Dataset.range(10))
            interleave(dataset_ops.Dataset.range(10))
            recorded_calls = patched_warn.call_args_list
            self.assertEmpty(recorded_calls)

    @parameterized.named_parameters(
        ("Concatenate", lambda x, y: x.concatenate(y),
         lambda: dataset_ops.Dataset.range(0),
         lambda: dataset_ops.Dataset.range(1)))
    def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
        """A binary transformation reports both sources, in argument order.

        Args:
          dataset_fn: Callable combining two datasets.
          input1_fn: Zero-argument callable building the first source.
          input2_fn: Zero-argument callable building the second source.
        """
        first = input1_fn()
        second = input2_fn()
        combined = dataset_fn(first, second)
        self.assertEqual([first, second], combined._inputs())

    @parameterized.named_parameters(
        ("ZipOne", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0))),
        ("ZipNest", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0),
          (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
        ("ZipTuple", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
    )
    def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
        """A variadic transformation reports its flattened sources as inputs.

        Args:
          dataset_fn: Callable combining a (possibly nested) structure of
            datasets.
          input_datasets_fn: Zero-argument callable building that structure.
        """
        sources = input_datasets_fn()
        expected_inputs = nest.flatten(sources)
        combined = dataset_fn(sources)
        self.assertEqual(expected_inputs, combined._inputs())

    def testFunctions(self):
        """A dataset with a single `map` reports exactly one function."""
        mapped = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
        functions = mapped._functions()
        self.assertLen(functions, 1)

    def testCollectInputs(self):
        """Walking `_inputs()` breadth-first visits shared inputs repeatedly."""
        ds1 = dataset_ops.Dataset.range(0)
        ds2 = ds1.concatenate(ds1)
        ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

        # First-in, first-out traversal starting from the outermost dataset.
        visited = []
        pending = [ds3]
        while pending:
            current = pending.pop(0)
            pending.extend(current._inputs())
            visited.append(current)

        # ds1 is reached once directly and twice through each copy of ds2.
        for expected_visits, ds in ((5, ds1), (2, ds2), (1, ds3)):
            self.assertEqual(expected_visits, visited.count(ds))

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         tensor_spec.TensorSpec([], dtypes.float32)),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), sparse_tensor.SparseTensorSpec([1],
                                                             dtypes.int32)),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        }, {
            "a":
            tensor_spec.TensorSpec([], dtypes.float32),
            "b": (
                tensor_spec.TensorSpec([1], dtypes.string),
                tensor_spec.TensorSpec([], dtypes.string),
            )
        }),
        ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3])),
         dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32))),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalSpec(tensor_spec.TensorSpec([],
                                                          dtypes.float32))),
    )
    def testDatasetSpec(self, tf_value_fn, expected_element_structure):
        """Checks the type spec derived from a dataset and its round-trip.

        Args:
          tf_value_fn: Zero-argument callable building the single element the
            dataset under test produces.
          expected_element_structure: Structure the dataset's element
            structure must be compatible with.
        """
        dataset = dataset_ops.Dataset.from_tensors(0).map(
            lambda _: tf_value_fn())
        dataset_structure = structure.type_spec_from_value(dataset)
        self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(dataset),
                                     expected_element_structure))
        # Whatever it yields, the dataset flattens to one scalar variant tensor.
        self.assertEqual([dtypes.variant],
                         structure.get_flat_tensor_types(dataset_structure))
        self.assertEqual([tensor_shape.scalar()],
                         structure.get_flat_tensor_shapes(dataset_structure))

        # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_dataset = dataset_structure._from_tensor_list(
            dataset_structure._to_tensor_list(dataset))

        value = tf_value_fn()

        if isinstance(value, dataset_ops.Dataset):
            # NOTE(review): this branch compares against `dataset`, not
            # `round_trip_dataset` - confirm that this is intended.
            self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
        elif isinstance(value, optional_ops.Optional):
            # Optionals are unwrapped so their held values can be compared.
            self.assertDatasetProduces(
                round_trip_dataset.map(lambda opt: opt.get_value()),
                [self.evaluate(value.get_value())],
                requires_initialization=True)
        else:
            self.assertDatasetProduces(round_trip_dataset,
                                       [self.evaluate(tf_value_fn())],
                                       requires_initialization=True)

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorOneShot(self):
        """Batching a dataset from inside another graph raises ValueError."""
        foreign_dataset = dataset_ops.Dataset.range(10)
        other_graph = ops.Graph()
        with other_graph.as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                _ = foreign_dataset.batch(2)

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorOneShotSimple(self):
        """A one-shot iterator made in another graph logs a warning."""
        foreign_dataset = dataset_ops.Dataset.range(10)
        expected_message = (
            "Please ensure that all datasets in the "
            "pipeline are created in the same graph as the iterator.")
        other_graph = ops.Graph()
        with other_graph.as_default():
            with test.mock.patch.object(logging, "warning") as patched_warning:
                _ = dataset_ops.make_one_shot_iterator(foreign_dataset)
                logged_call = str(patched_warning.call_args)
                self.assertRegexpMatches(logged_call, expected_message)

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorInitializable(self):
        """Batching a dataset from inside another graph raises ValueError."""
        foreign_dataset = dataset_ops.Dataset.range(10)
        other_graph = ops.Graph()
        with other_graph.as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                _ = foreign_dataset.batch(2)

    @parameterized.named_parameters(
        ("Async", context.ASYNC),
        ("Sync", context.SYNC),
    )
    def testDatasetEagerIteration(self, execution_mode):
        """Iterating a range dataset eagerly yields 0, 1, 2, ... in order.

        Args:
          execution_mode: Eager execution mode to run the iteration under.
        """
        with context.eager_mode(), context.execution_mode(execution_mode):
            dataset = dataset_ops.Dataset.range(10)
            for expected, elem in enumerate(dataset):
                self.assertEqual(expected, elem.numpy())

    def testDatasetAsFunctionArgument(self):
        """Datasets passed to a traced function are arguments, not captures."""
        @def_function.function
        def _uses_dataset(d):
            total = array_ops.zeros([], dtype=dtypes.int64)
            for value in d:
                total += value
            return total

        with ops.device("CPU"):
            ten = dataset_ops.Dataset.range(10)
            self.assertEqual(45, self.evaluate(_uses_dataset(ten)))
            eleven = dataset_ops.Dataset.range(11)
            self.assertEqual(55, self.evaluate(_uses_dataset(eleven)))

            trace_for_ten = _uses_dataset.get_concrete_function(ten)
            # The dataset should not be a captured input
            self.assertEmpty(trace_for_ten.graph.captures)

            # The two datasets have the same structure and so should re-use a
            # trace.
            trace_for_eleven = _uses_dataset.get_concrete_function(eleven)
            self.assertIs(trace_for_ten, trace_for_eleven)

            # With a different structure we should use a different trace.
            zipped = dataset_ops.Dataset.zip((ten, eleven))
            trace_for_zipped = _uses_dataset.get_concrete_function(zipped)
            self.assertIsNot(trace_for_ten, trace_for_zipped)

    def testLimitedRetracing(self):
        """Different datasets of the same structure share a single trace."""
        trace_count = [0]

        @def_function.function
        def f(ds):
            # Python side effect: bumps the counter each time `f` is traced.
            trace_count[0] += 1
            counter = np.int64(0)
            for elem in ds:
                counter += elem
            return counter

        # Pairs of (dataset, expected sum of its elements).
        cases = (
            (dataset_ops.Dataset.range(5), 10),
            (dataset_ops.Dataset.range(10), 45),
        )

        for _ in range(10):
            for dataset, expected_total in cases:
                self.assertEqual(self.evaluate(f(dataset)), expected_total)
            self.assertEqual(trace_count[0], 1)
Exemple #29
0
 def compute_output_signature(self, input_spec):
   """Returns the sparse output spec for a sequence of input specs.

   The output dtype is int64 when `self.num_bins` is truthy and string
   otherwise; the shape comes from `compute_output_shape`.
   """
   output_shape = self.compute_output_shape(
       [spec.shape for spec in input_spec])
   if self.num_bins:
     output_dtype = dtypes.int64
   else:
     output_dtype = dtypes.string
   return sparse_tensor.SparseTensorSpec(
       shape=output_shape, dtype=output_dtype)
Exemple #30
0
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
                    test_util.TensorFlowTestCase):

    # pylint: disable=g-long-lambda,protected-access
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0), tensor_spec.TensorSpec,
         [dtypes.float32], [[]]),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         tensor_array_ops.TensorArraySpec, [dtypes.variant], [[]]),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         sparse_tensor.SparseTensorSpec, [dtypes.variant], [None]),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
         ragged_tensor.RaggedTensorSpec, [dtypes.variant], [None]),
        ("Nested_0", lambda:
         (constant_op.constant(37.0), constant_op.constant([1, 2, 3])), tuple,
         [dtypes.float32, dtypes.int32], [[], [3]]),
        ("Nested_1", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, dict, [dtypes.float32, dtypes.int32], [[], [3]]),
        ("Nested_2", lambda: {
            "a":
            constant_op.constant(37.0),
            "b":
            (sparse_tensor.
             SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
             sparse_tensor.SparseTensor(
                 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }, dict, [dtypes.float32, dtypes.variant, dtypes.variant], [[], None,
                                                                    None]),
    )
    def testFlatStructure(self, value_fn, expected_structure, expected_types,
                          expected_shapes):
        """Checks the spec class and flat dtypes/shapes derived from a value.

        Args:
          value_fn: Zero-argument callable producing the value under test.
          expected_structure: Class the derived spec must be an instance of.
          expected_types: Expected list of flat tensor dtypes.
          expected_shapes: Expected flat shapes as lists; a `None` entry
            means the corresponding shape must have unknown rank.
        """
        spec = structure.type_spec_from_value(value_fn())
        self.assertIsInstance(spec, expected_structure)

        self.assertEqual(expected_types,
                         structure.get_flat_tensor_types(spec))

        actual_shapes = structure.get_flat_tensor_shapes(spec)
        self.assertLen(actual_shapes, len(expected_shapes))
        for want, got in zip(expected_shapes, actual_shapes):
            if want is not None:
                self.assertEqual(got.as_list(), want)
            else:
                # No expected dims given: the shape must have unknown rank.
                self.assertEqual(got.ndims, None)

    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0), lambda: [
            constant_op.constant(38.0),
            array_ops.placeholder(dtypes.float32),
            variables.Variable(100.0), 42.0,
            np.array(42.0, dtype=np.float32)
        ],
         lambda: [constant_op.constant([1.0, 2.0]),
                  constant_op.constant(37)]),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0), lambda: [
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(3, ), size=0),
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(3, ), size=10)
            ], lambda: [
                tensor_array_ops.TensorArray(
                    dtype=dtypes.int32, element_shape=(3, ), size=0),
                tensor_array_ops.TensorArray(
                    dtype=dtypes.float32, element_shape=(), size=0)
            ]),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [
                sparse_tensor.SparseTensor(indices=[[1, 1], [3, 4]],
                                           values=[10, -1],
                                           dense_shape=[4, 5]),
                sparse_tensor.SparseTensorValue(indices=[[1, 1], [3, 4]],
                                                values=[10, -1],
                                                dense_shape=[4, 5]),
                array_ops.sparse_placeholder(dtype=dtypes.int32),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None])
            ], lambda: [
                constant_op.constant(37, shape=[4, 5]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
                array_ops.sparse_placeholder(dtype=dtypes.int32,
                                             shape=[None, None, None]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
            ]),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[1, 2], [], [3]]), lambda: [
             ragged_factory_ops.constant([[1, 2], [3, 4], []]),
             ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
         ], lambda: [
             ragged_factory_ops.constant(1),
             ragged_factory_ops.constant([1, 2]),
             ragged_factory_ops.constant([[1], [2]]),
             ragged_factory_ops.constant([["a", "b"]]),
         ]),
        ("Nested", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6])
        }], lambda: [{
            "a": constant_op.constant(15.0),
            "b": constant_op.constant([4, 5, 6, 7])
        }, {
            "a": constant_op.constant(15),
            "b": constant_op.constant([4, 5, 6])
        }, {
            "a":
            constant_op.constant(15),
            "b":
            sparse_tensor.SparseTensor(
                indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
        }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
    )
    @test_util.run_deprecated_v1
    def testIsCompatibleWithStructure(self, original_value_fn,
                                      compatible_values_fn,
                                      incompatible_values_fn):
        """Checks `structure.are_compatible` against a reference value.

        Args:
          original_value_fn: Zero-argument callable producing the reference
            value whose spec every candidate is compared with.
          compatible_values_fn: Zero-argument callable producing values whose
            specs must be compatible with the reference spec.
          incompatible_values_fn: Zero-argument callable producing values
            whose specs must not be compatible with the reference spec.
        """
        reference = original_value_fn()
        should_match = compatible_values_fn()
        should_not_match = incompatible_values_fn()
        reference_spec = structure.type_spec_from_value(reference)

        def _compatible_with_reference(candidate):
            candidate_spec = structure.type_spec_from_value(candidate)
            return structure.are_compatible(reference_spec, candidate_spec)

        for candidate in should_match:
            self.assertTrue(_compatible_with_reference(candidate))
        for candidate in should_not_match:
            self.assertFalse(_compatible_with_reference(candidate))

    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         lambda: constant_op.constant(42.0),
         lambda: constant_op.constant([5])),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.int32, element_shape=(), size=0)),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[3]], values=[-1], dense_shape=[5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])),
        ("RaggedTensor",
         lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
         lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]), lambda:
         ragged_factory_ops.constant([[[1]], [[2], [3]]], ragged_rank=1),
         lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
         lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])),
        ("Nested", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: {
            "a": constant_op.constant(42.0),
            "b": constant_op.constant([4, 5, 6])
        }, lambda: {
            "a": constant_op.constant([1, 2, 3]),
            "b": constant_op.constant(37.0)
        }),
    )  # pyformat: disable
    def testStructureFromValueEquality(self, value1_fn, value2_fn,
                                       *not_equal_value_fns):
        """Checks `==`, `!=` and `hash` on specs derived from values.

        Args:
          value1_fn: Zero-argument callable producing the first value.
          value2_fn: Zero-argument callable producing a value whose spec must
            equal the first value's spec.
          *not_equal_value_fns: Zero-argument callables producing values
            whose specs must differ from both of the specs above.
        """
        # pylint: disable=g-generic-assert
        spec_a = structure.type_spec_from_value(value1_fn())
        spec_b = structure.type_spec_from_value(value2_fn())

        # Both the assert helper and the raw operators are used so that
        # __eq__ and __ne__ are each exercised explicitly.
        for same in (spec_a, spec_b):
            self.assertEqual(spec_a, same)  # __eq__
        for same in (spec_a, spec_b):
            self.assertFalse(spec_a != same)  # __ne__

        # Equal specs must hash equally, component by component.
        flat_a = nest.flatten(spec_a)
        flat_b = nest.flatten(spec_b)
        for comp_a, comp_b in zip(flat_a, flat_b):
            self.assertEqual(hash(comp_a), hash(comp_a))
            self.assertEqual(hash(comp_a), hash(comp_b))

        for make_value in not_equal_value_fns:
            spec_c = structure.type_spec_from_value(make_value())
            for equal_spec in (spec_a, spec_b):
                self.assertNotEqual(equal_spec, spec_c)  # __ne__
            for equal_spec in (spec_a, spec_b):
                self.assertFalse(equal_spec == spec_c)  # __eq__

    @parameterized.named_parameters(
        ("RaggedTensor_RaggedRank",
         ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
         ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 2)),
        ("RaggedTensor_Shape",
         ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32, 1),
         ragged_tensor.RaggedTensorSpec([5, None], dtypes.int32, 1)),
        ("RaggedTensor_DType",
         ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
         ragged_tensor.RaggedTensorSpec(None, dtypes.float32, 1)),
    )
    def testRaggedStructureInequality(self, s1, s2):
        """Checks that the two given specs compare as unequal.

        Args:
          s1: First spec.
          s2: Second spec, expected to differ from `s1`.
        """
        # pylint: disable=g-generic-assert
        self.assertNotEqual(s1, s2)  # Goes through __ne__.
        compares_equal = s1 == s2  # Goes through __eq__.
        self.assertFalse(compares_equal)

    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         lambda: constant_op.constant(42.0),
         lambda: constant_op.constant([5])),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.float32, element_shape=(3, ), size=0),
         lambda: tensor_array_ops.TensorArray(
             dtype=dtypes.int32, element_shape=(), size=0)),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[3]], values=[-1], dense_shape=[5])),
        ("Nested", lambda: {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }, lambda: {
            "a": constant_op.constant(42.0),
            "b": constant_op.constant([4, 5, 6])
        }, lambda: {
            "a": constant_op.constant([1, 2, 3]),
            "b": constant_op.constant(37.0)
        }),
    )
    def testHash(self, value1_fn, value2_fn, value3_fn):
        """Checks component-wise hashes of specs derived from three values.

        Args:
          value1_fn: Zero-argument callable producing the first value.
          value2_fn: Zero-argument callable producing a value whose spec
            components must hash the same as the first value's.
          value3_fn: Zero-argument callable producing a value whose spec
            components must hash differently from the other two.
        """
        value_fns = (value1_fn, value2_fn, value3_fn)
        specs = [structure.type_spec_from_value(fn()) for fn in value_fns]
        flat_a, flat_b, flat_c = [nest.flatten(spec) for spec in specs]

        for same_a, same_b, different in zip(flat_a, flat_b, flat_c):
            # Hashing is stable for one object and agrees across equal specs.
            self.assertEqual(hash(same_a), hash(same_a))
            self.assertEqual(hash(same_a), hash(same_b))
            # The third spec's components are expected to hash differently.
            self.assertNotEqual(hash(same_a), hash(different))
            self.assertNotEqual(hash(same_b), hash(different))

    @parameterized.named_parameters(
        (
            "Tensor",
            lambda: constant_op.constant(37.0),
        ),
        (
            "SparseTensor",
            lambda: sparse_tensor.SparseTensor(
                indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
        ),
        ("TensorArray", lambda: tensor_array_ops.TensorArray(
            dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
        (
            "RaggedTensor",
            lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
        ),
        (
            "Nested_0",
            lambda: {
                "a": constant_op.constant(37.0),
                "b": constant_op.constant([1, 2, 3])
            },
        ),
        (
            "Nested_1",
            lambda: {
                "a":
                constant_op.constant(37.0),
                "b": (sparse_tensor.SparseTensor(
                    indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                      sparse_tensor.SparseTensor(
                          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
            },
        ),
    )
    def testRoundTripConversion(self, value_fn):
        """Checks that a value survives `to_tensor_list`/`from_tensor_list`.

        Args:
          value_fn: Zero-argument callable producing the value to round-trip.
        """
        original = value_fn()
        spec = structure.type_spec_from_value(original)

        def _evaluable(v):
            # TensorArrays are stacked before being handed to `evaluate`.
            if not isinstance(v, tensor_array_ops.TensorArray):
                return v
            return v.stack()

        before = self.evaluate(_evaluable(original))

        tensor_list = structure.to_tensor_list(spec, original)
        restored = structure.from_tensor_list(spec, tensor_list)
        after = self.evaluate(_evaluable(restored))

        flat_before = nest.flatten(before)
        flat_after = nest.flatten(after)
        for b, a in zip(flat_before, flat_after):
            if isinstance(b, sparse_tensor.SparseTensorValue):
                # Sparse values are compared field by field.
                self.assertAllEqual(b.indices, a.indices)
                self.assertAllEqual(b.values, a.values)
                self.assertAllEqual(b.dense_shape, a.dense_shape)
            else:
                # Everything else, ragged values included, is compared whole.
                self.assertAllEqual(b, a)

    # pylint: enable=g-long-lambda

    def preserveStaticShape(self):
        """Round-trips a ragged and a sparse tensor and checks static shapes.

        NOTE(review): unlike its siblings, this method's name does not start
        with `test`, so it is presumably never collected by the test runner.
        Confirm whether that is intentional before renaming it.
        """

        def _round_trip(value):
            spec = structure.type_spec_from_value(value)
            return structure.from_tensor_list(
                spec, structure.to_tensor_list(spec, value))

        ragged = ragged_factory_ops.constant([[1, 2], [], [3]])
        ragged_restored = _round_trip(ragged)
        self.assertEqual(ragged_restored.row_splits.shape.as_list(),
                         ragged.row_splits.shape.as_list())
        self.assertEqual(ragged_restored.values.shape.as_list(), [None])

        sparse = sparse_tensor.SparseTensor(
            indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
        sparse_restored = _round_trip(sparse)
        self.assertEqual(sparse_restored.indices.shape.as_list(), [None, 2])
        self.assertEqual(sparse_restored.values.shape.as_list(), [None])
        self.assertEqual(sparse_restored.dense_shape.shape.as_list(),
                         sparse.dense_shape.shape.as_list())

    def testPreserveTensorArrayShape(self):
        """Checks an explicit TensorArray element shape survives a round trip."""
        array = tensor_array_ops.TensorArray(
            dtype=dtypes.int32, size=1, element_shape=(3, ))
        array_spec = structure.type_spec_from_value(array)
        flat = structure.to_tensor_list(array_spec, array)
        restored = structure.from_tensor_list(array_spec, flat)
        self.assertEqual(restored.element_shape.as_list(), [3])

    def testPreserveInferredTensorArrayShape(self):
        """Checks an inferred TensorArray element shape survives a round trip."""
        # No element_shape is given here; it is inferred from the write.
        array = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=1)
        array = array.write(0, [1, 2, 3])
        array_spec = structure.type_spec_from_value(array)
        flat = structure.to_tensor_list(array_spec, array)
        restored = structure.from_tensor_list(array_spec, flat)
        self.assertEqual(restored.element_shape.as_list(), [3])

    def testIncompatibleStructure(self):
        """Checks that mismatched specs, values and tensor lists are rejected.

        Builds a dense tensor, a sparse tensor and a dict of tensors, each
        with its own spec and flattened tensor list, then asserts that:

        1. `to_tensor_list` raises when a spec is paired with a value that
           was built for one of the other two specs.
        2. `from_tensor_list` raises when a spec is paired with a tensor list
           flattened under one of the other two specs.
        """
        dense_value = constant_op.constant(42.0)
        dense_spec = structure.type_spec_from_value(dense_value)
        dense_flat = structure.to_tensor_list(dense_spec, dense_value)

        sparse_value = sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1])
        sparse_spec = structure.type_spec_from_value(sparse_value)
        sparse_flat = structure.to_tensor_list(sparse_spec, sparse_value)

        nest_value = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        nest_spec = structure.type_spec_from_value(nest_value)
        nest_flat = structure.to_tensor_list(nest_spec, nest_value)

        nest_mismatch = (
            "The two structures don't have the same nested structure.")

        # (spec, value, expected exception, expected message regex)
        flatten_failures = (
            (dense_spec, sparse_value, ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*float32.* and shape \(\)"),
            (dense_spec, nest_value, ValueError, nest_mismatch),
            (sparse_spec, dense_value, TypeError,
             "Neither a SparseTensor nor SparseTensorValue"),
            (sparse_spec, nest_value, ValueError, nest_mismatch),
            (nest_spec, dense_value, ValueError, nest_mismatch),
            (nest_spec, sparse_value, ValueError, nest_mismatch),
        )
        for spec, value, error, pattern in flatten_failures:
            with self.assertRaisesRegex(error, pattern):
                structure.to_tensor_list(spec, value)

        # (spec, flat tensor list, expected message regex); all ValueError.
        rebuild_failures = (
            (dense_spec, sparse_flat, r"Incompatible input:"),
            (dense_spec, nest_flat, "Expected 1 tensors but got 2."),
            (sparse_spec, dense_flat, "Incompatible input: "),
            (sparse_spec, nest_flat, "Expected 1 tensors but got 2."),
            (nest_spec, dense_flat, "Expected 2 tensors but got 1."),
            (nest_spec, sparse_flat, "Expected 2 tensors but got 1."),
        )
        for spec, flat, pattern in rebuild_failures:
            with self.assertRaisesRegex(ValueError, pattern):
                structure.from_tensor_list(spec, flat)

    def testIncompatibleNestedStructure(self):
        """Checks that mismatched nested specs and values are rejected.

        Builds three dict values with their specs and flattened tensor lists,
        then asserts that:

        1. `to_tensor_list` raises when a spec is paired with one of the
           other two values.
        2. `from_tensor_list` raises when a spec is paired with a tensor list
           flattened under one of the other two specs.
        """
        dense_nest = {
            "a": constant_op.constant(37.0),
            "b": constant_op.constant([1, 2, 3])
        }
        dense_nest_spec = structure.type_spec_from_value(dense_nest)
        dense_nest_flat = structure.to_tensor_list(dense_nest_spec,
                                                   dense_nest)

        # Same keys as `dense_nest`, but "b" holds a SparseTensor instead of
        # a dense tensor.
        sparse_nest = {
            "a": constant_op.constant(37.0),
            "b": sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1])
        }
        sparse_nest_spec = structure.type_spec_from_value(sparse_nest)
        sparse_nest_flat = structure.to_tensor_list(sparse_nest_spec,
                                                    sparse_nest)

        # Here "b" is a tuple of two SparseTensors, so the nesting itself
        # differs from both values above.
        deeper_nest = {
            "a": constant_op.constant(37.0),
            "b": (sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(
                      indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
        }
        deeper_nest_spec = structure.type_spec_from_value(deeper_nest)
        deeper_nest_flat = structure.to_tensor_list(deeper_nest_spec,
                                                    deeper_nest)

        # Only this generic sentence is matched for nesting mismatches; the
        # rest of the error message is not checked.
        nest_mismatch = (
            "The two structures don't have the same nested structure.")

        # (spec, value, expected exception, expected message regex)
        flatten_failures = (
            (dense_nest_spec, sparse_nest, ValueError,
             r"SparseTensor.* is not convertible to a tensor with "
             r"dtype.*int32.* and shape \(3,\)"),
            (dense_nest_spec, deeper_nest, ValueError, nest_mismatch),
            (sparse_nest_spec, dense_nest, TypeError,
             "Neither a SparseTensor nor SparseTensorValue"),
            (sparse_nest_spec, deeper_nest, ValueError, nest_mismatch),
            (deeper_nest_spec, dense_nest, ValueError, nest_mismatch),
            (deeper_nest_spec, sparse_nest, ValueError, nest_mismatch),
        )
        for spec, value, error, pattern in flatten_failures:
            with self.assertRaisesRegex(error, pattern):
                structure.to_tensor_list(spec, value)

        # (spec, flat tensor list, expected message regex); all ValueError.
        rebuild_failures = (
            (dense_nest_spec, sparse_nest_flat, r"Incompatible input:"),
            (dense_nest_spec, deeper_nest_flat,
             "Expected 2 tensors but got 3."),
            (sparse_nest_spec, dense_nest_flat, "Incompatible input: "),
            (sparse_nest_spec, deeper_nest_flat,
             "Expected 2 tensors but got 3."),
            (deeper_nest_spec, dense_nest_flat,
             "Expected 3 tensors but got 2."),
            (deeper_nest_spec, sparse_nest_flat,
             "Expected 3 tensors but got 2."),
        )
        for spec, flat, pattern in rebuild_failures:
            with self.assertRaisesRegex(ValueError, pattern):
                structure.from_tensor_list(spec, flat)

    @parameterized.named_parameters(
        ("Tensor", dtypes.float32, tensor_shape.TensorShape(
            []), ops.Tensor, tensor_spec.TensorSpec([], dtypes.float32)),
        ("SparseTensor", dtypes.int32, tensor_shape.TensorShape(
            [2, 2]), sparse_tensor.SparseTensor,
         sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)),
        ("TensorArray_0", dtypes.int32,
         tensor_shape.TensorShape([None, True, 2, 2
                                   ]), tensor_array_ops.TensorArray,
         tensor_array_ops.TensorArraySpec(
             [2, 2], dtypes.int32, dynamic_size=None, infer_shape=True)),
        ("TensorArray_1", dtypes.int32,
         tensor_shape.TensorShape([True, None, 2, 2
                                   ]), tensor_array_ops.TensorArray,
         tensor_array_ops.TensorArraySpec(
             [2, 2], dtypes.int32, dynamic_size=True, infer_shape=None)),
        ("TensorArray_2", dtypes.int32,
         tensor_shape.TensorShape([True, False, 2, 2
                                   ]), tensor_array_ops.TensorArray,
         tensor_array_ops.TensorArraySpec(
             [2, 2], dtypes.int32, dynamic_size=True, infer_shape=False)),
        ("RaggedTensor", dtypes.int32, tensor_shape.TensorShape([2, None]),
         ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1),
         ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1)),
        ("Nested", {
            "a": dtypes.float32,
            "b": (dtypes.int32, dtypes.string)
        }, {
            "a": tensor_shape.TensorShape([]),
            "b":
            (tensor_shape.TensorShape([2, 2]), tensor_shape.TensorShape([]))
        }, {
            "a": ops.Tensor,
            "b": (sparse_tensor.SparseTensor, ops.Tensor)
        }, {
            "a":
            tensor_spec.TensorSpec([], dtypes.float32),
            "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                  tensor_spec.TensorSpec([], dtypes.string))
        }),
    )
    def testConvertLegacyStructure(self, output_types, output_shapes,
                                   output_classes, expected_structure):
        """Checks `convert_legacy_structure` against an expected structure.

        Args:
          output_types: Legacy (possibly nested) dtypes.
          output_shapes: Legacy (possibly nested) shapes.
          output_classes: Legacy (possibly nested) classes or specs.
          expected_structure: Structure the conversion must compare equal to.
        """
        self.assertEqual(
            structure.convert_legacy_structure(output_types, output_shapes,
                                               output_classes),
            expected_structure)

    def testNestedNestedStructure(self):
        """Checks flatten/rebuild of a two-level tuple of dense tensors.

        The tensors coming out of `to_tensor_list`, `from_tensor_list` and
        `from_compatible_tensor_list` must be the very same objects that
        went in.
        """
        spec = (tensor_spec.TensorSpec([], dtypes.int64),
                (tensor_spec.TensorSpec([], dtypes.float32),
                 tensor_spec.TensorSpec([], dtypes.string)))

        int64_t = constant_op.constant(37, dtype=dtypes.int64)
        float32_t = constant_op.constant(42.0)
        string_t = constant_op.constant("Foo")

        flat = structure.to_tensor_list(spec, (int64_t, (float32_t, string_t)))
        for want, got in zip([int64_t, float32_t, string_t], flat):
            self.assertIs(want, got)

        # Both rebuild entry points must hand back the original objects.
        rebuilders = (structure.from_tensor_list,
                      structure.from_compatible_tensor_list)
        for rebuild in rebuilders:
            got_int64, (got_float32, got_string) = rebuild(spec, flat)
            self.assertIs(int64_t, got_int64)
            self.assertIs(float32_t, got_float32)
            self.assertIs(string_t, got_string)

    @parameterized.named_parameters(
        ("Tensor", tensor_spec.TensorSpec([], dtypes.float32), 32,
         tensor_spec.TensorSpec([32], dtypes.float32)),
        ("TensorUnknown", tensor_spec.TensorSpec([], dtypes.float32), None,
         tensor_spec.TensorSpec([None], dtypes.float32)),
        ("SparseTensor", sparse_tensor.SparseTensorSpec([None],
                                                        dtypes.float32), 32,
         sparse_tensor.SparseTensorSpec([32, None], dtypes.float32)),
        ("SparseTensorUnknown",
         sparse_tensor.SparseTensorSpec([4], dtypes.float32), None,
         sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32)),
        ("RaggedTensor",
         ragged_tensor.RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
         ragged_tensor.RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),
        ("RaggedTensorUnknown",
         ragged_tensor.RaggedTensorSpec([4, None], dtypes.float32, 1), None,
         ragged_tensor.RaggedTensorSpec([None, 4, None], dtypes.float32, 2)),
        ("Nested", {
            "a":
            tensor_spec.TensorSpec([], dtypes.float32),
            "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                  tensor_spec.TensorSpec([], dtypes.string))
        }, 128, {
            "a":
            tensor_spec.TensorSpec([128], dtypes.float32),
            "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                  tensor_spec.TensorSpec([128], dtypes.string))
        }),
    )
    def testBatch(self, element_structure, batch_size,
                  expected_batched_structure):
        """Checks `_batch` applied to every component of a structure.

        Args:
          element_structure: (Possibly nested) specs to batch.
          batch_size: Value passed to each component's `_batch`.
          expected_batched_structure: Structure the result must equal.
        """

        def _batched(component_spec):
            return component_spec._batch(batch_size)

        self.assertEqual(nest.map_structure(_batched, element_structure),
                         expected_batched_structure)

    @parameterized.named_parameters(
        ("Tensor", tensor_spec.TensorSpec(
            [32], dtypes.float32), tensor_spec.TensorSpec([], dtypes.float32)),
        ("TensorUnknown", tensor_spec.TensorSpec([None], dtypes.float32),
         tensor_spec.TensorSpec([], dtypes.float32)),
        ("SparseTensor",
         sparse_tensor.SparseTensorSpec([32, None], dtypes.float32),
         sparse_tensor.SparseTensorSpec([None], dtypes.float32)),
        ("SparseTensorUnknown",
         sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32),
         sparse_tensor.SparseTensorSpec([4], dtypes.float32)),
        ("RaggedTensor",
         ragged_tensor.RaggedTensorSpec([32, None, None], dtypes.float32, 2),
         ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)),
        ("RaggedTensorUnknown",
         ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.float32, 2),
         ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)),
        ("Nested", {
            "a":
            tensor_spec.TensorSpec([128], dtypes.float32),
            "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                  tensor_spec.TensorSpec([None], dtypes.string))
        }, {
            "a":
            tensor_spec.TensorSpec([], dtypes.float32),
            "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                  tensor_spec.TensorSpec([], dtypes.string))
        }),
    )
    def testUnbatch(self, element_structure, expected_unbatched_structure):
        """Checks `_unbatch` applied to every component of a structure.

        Args:
          element_structure: (Possibly nested) specs to unbatch.
          expected_unbatched_structure: Structure the result must equal.
        """

        def _unbatched(component_spec):
            return component_spec._unbatch()

        self.assertEqual(nest.map_structure(_unbatched, element_structure),
                         expected_unbatched_structure)

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
         lambda: constant_op.constant([1.0, 2.0])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
         lambda: sparse_tensor.SparseTensor(
             indices=[[0]], values=[13], dense_shape=[2])),
        ("RaggedTensor", lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
         lambda: ragged_factory_ops.constant([[1]])),
        ("Nest", lambda:
         (constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
          sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])),
         lambda: (constant_op.constant([1.0, 2.0]),
                  sparse_tensor.SparseTensor(
                      indices=[[0]], values=[13], dense_shape=[2]))),
    )
    def testToBatchedTensorList(self, value_fn, element_0_fn):
        # Converts the batched value produced by `value_fn` to its batched
        # tensor list, then checks that decoding the 0th slice of every
        # tensor reproduces the value produced by `element_0_fn`.
        value = value_fn()
        batched_spec = structure.type_spec_from_value(value)
        tensor_list = structure.to_batched_tensor_list(batched_spec, value)

        # The batch dimension is 2 for all of the test cases.
        # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
        # tensors in which we store sparse tensors.
        for component in tensor_list:
            if component.dtype != dtypes.variant:
                leading_dim = array_ops.shape(component)[0]
                self.assertEqual(2, self.evaluate(leading_dim))

        # Test that the 0th element from the unbatched tensor is equal to the
        # expected value.
        expected_element_0 = self.evaluate(element_0_fn())

        def _unbatch_component(component_spec):
            return component_spec._unbatch()

        element_spec = nest.map_structure(_unbatch_component, batched_spec)
        first_slices = [component[0] for component in tensor_list]
        actual_element_0 = structure.from_tensor_list(element_spec,
                                                      first_slices)

        expected_flat = nest.flatten(expected_element_0)
        actual_flat = nest.flatten(actual_element_0)
        for expected, actual in zip(expected_flat, actual_flat):
            self.assertValuesEqual(expected, actual)

    # pylint: enable=g-long-lambda

    def testDatasetSpecConstructor(self):
        # Builds a DatasetSpec from a dict of ragged, sparse and dense
        # component specs and checks the fields stored on it.
        element_spec = {
            "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
            "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
            "t": tensor_spec.TensorSpec([10, 8], dtypes.string),
        }
        dataset_spec = dataset_ops.DatasetSpec(element_spec, [5])
        self.assertEqual(dataset_spec._element_spec, element_spec)
        # Note: shape was automatically converted from a list to a TensorShape.
        expected_dataset_shape = tensor_shape.TensorShape([5])
        self.assertEqual(dataset_spec._dataset_shape, expected_dataset_shape)

    def testCustomMapping(self):
        # The spec derived from a CustomMap value should itself be a
        # CustomMap whose "foo" entry is a scalar float32 TensorSpec.
        mapping = CustomMap(foo=constant_op.constant(37.))
        mapping_spec = structure.type_spec_from_value(mapping)
        self.assertIsInstance(mapping_spec, CustomMap)
        expected_foo_spec = tensor_spec.TensorSpec([], dtypes.float32)
        self.assertEqual(mapping_spec["foo"], expected_foo_spec)

    def testObjectProxy(self):
        # A namedtuple wrapped in a wrapt.ObjectProxy should produce the same
        # spec as the unwrapped namedtuple.
        nt_type = collections.namedtuple("A", ["x", "y"])
        wrapped_value = wrapt.ObjectProxy(nt_type(1, 2))
        proxied_spec = structure.type_spec_from_value(wrapped_value)
        plain_spec = structure.type_spec_from_value(nt_type(1, 2))
        self.assertEqual(plain_spec, proxied_spec)