Exemplo n.º 1
0
def _optional_spec_test_combinations():
    """Returns the parameter combinations for the optional-spec tests.

    Every case contributes one `combinations.combine(...)` result binding:

      * `tf_value_fn`: a `combinations.NamedObject` labelled with the case
        name and wrapping a zero-argument factory for the TF value.
      * `expected_value_structure`: the `TypeSpec` that value is expected
        to have.

    The per-case results are concatenated in case order (Dense, Sparse,
    Nest, Optional).
    """
    # pylint: disable=g-long-lambda
    dense_case = (
        "Dense",
        lambda: constant_op.constant(37.0),
        tensor_spec.TensorSpec([], dtypes.float32),
    )
    sparse_case = (
        "Sparse",
        lambda: sparse_tensor.SparseTensor(
            indices=[[0, 1]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[10, 10]),
        sparse_tensor.SparseTensorSpec([10, 10], dtypes.int32),
    )
    nest_case = (
        "Nest",
        lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")),
        },
        {
            "a": tensor_spec.TensorSpec([], dtypes.float32),
            "b": (
                tensor_spec.TensorSpec([1], dtypes.string),
                tensor_spec.TensorSpec([], dtypes.string),
            ),
        },
    )
    optional_case = (
        "Optional",
        lambda: optional_ops.Optional.from_value(37.0),
        optional_ops.OptionalSpec(
            tensor_spec.TensorSpec([], dtypes.float32)),
    )

    # Accumulate with list `+` (not `+=`/`extend`) so each step concatenates
    # exactly the list returned by `combinations.combine`.
    merged = []
    for case_name, make_value, value_spec in (dense_case, sparse_case,
                                              nest_case, optional_case):
        merged = merged + combinations.combine(
            tf_value_fn=combinations.NamedObject(case_name, make_value),
            expected_value_structure=value_spec)
    return merged
Exemplo n.º 2
0
 def testOptionalDatasetSpec(self):
     """Runs `checkDatasetSpec` on an `Optional` holding 37.0.

     The expected spec is an `OptionalSpec` wrapping a scalar float32
     `TensorSpec`. `checkDatasetSpec` is defined outside this excerpt.
     """
     optional_value = optional_ops.Optional.from_value(37.0)
     scalar_float_spec = tensor_spec.TensorSpec([], dtypes.float32)
     expected_spec = optional_ops.OptionalSpec(scalar_float_spec)
     self.checkDatasetSpec(optional_value, expected_spec)
Exemplo n.º 3
0
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for `Dataset` serialization, input tracking, specs and tracing."""

    def testAsSerializedGraph(self):
        """The serialized graph of `range(10)` contains a RangeDataset node."""
        dataset = dataset_ops.Dataset.range(10)
        graph = graph_pb2.GraphDef().FromString(
            self.evaluate(dataset._as_serialized_graph()))
        # Look for the op itself: a `!=` comparison inside `any` would be
        # satisfied by any other node in the graph and so check nothing.
        self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

    def testAsFunctionWithMap(self):
        """A mapped dataset revived from its traced variant yields the same data."""
        if not context.executing_eagerly():
            self.skipTest("Only works executing eagerly")
        with ops.device("CPU"):
            original_dataset = dataset_ops.Dataset.range(5).map(
                lambda x: x * 2)
            fn = original_dataset._trace_variant_creation()
            variant = fn()

            revived_dataset = dataset_ops._VariantDataset(
                variant, original_dataset.element_spec)
            self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

    def testAsFunctionWithMapInFlatMap(self):
        """Like testAsFunctionWithMap, but with a map nested inside flat_map."""
        if not context.executing_eagerly():
            self.skipTest("Only works executing eagerly")
        with ops.device("CPU"):
            original_dataset = dataset_ops.Dataset.range(5).flat_map(
                lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
            fn = original_dataset._trace_variant_creation()
            variant = fn()

            revived_dataset = dataset_ops._VariantDataset(
                variant, original_dataset.element_spec)
            self.assertDatasetProduces(revived_dataset, list(original_dataset))

    @staticmethod
    def make_apply_fn(dataset):
        """Returns a function that applies `cache()` through `Dataset.apply`.

        Args:
          dataset: Unused; the returned function takes its own dataset.
        """
        def apply_fn(dataset):
            def _apply_fn(dataset):
                return dataset.cache()

            return dataset.apply(_apply_fn)

        return apply_fn

    @staticmethod
    def make_gen():
        """Returns a generator function that yields the single value 42."""
        def gen():
            yield 42

        return gen

    @staticmethod
    def make_interleave_fn(dataset, num_parallel_calls=None):
        """Returns a function interleaving its input with empty `range(0)`s.

        Args:
          dataset: Unused; the returned function takes its own dataset.
          num_parallel_calls: Forwarded to `Dataset.interleave`.
        """
        def interleave_fn(dataset):
            return dataset.interleave(lambda x: dataset_ops.Dataset.range(0),
                                      cycle_length=2,
                                      num_parallel_calls=num_parallel_calls)

        return interleave_fn

    # NOTE(review): FromGenerator and TFRecord are expected to report one
    # input; presumably they are built on top of another dataset
    # internally — confirm against dataset_ops/readers.
    @parameterized.named_parameters(
        ("FixedLengthRecord",
         lambda: readers.FixedLengthRecordDataset("", 42)),
        ("FromGenerator", lambda: dataset_ops.Dataset.from_generator(
            DatasetTest.make_gen(), dtypes.int32), 1),
        ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
        ("FromTensorSlices",
         lambda: dataset_ops.Dataset.from_tensor_slices([42])),
        ("Range", lambda: dataset_ops.Dataset.range(10)),
        ("TextLine", lambda: readers.TextLineDataset("")),
        ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
    )
    def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
        """Source datasets report `num_inputs` entries from `_inputs()`."""
        self.assertLen(dataset_fn()._inputs(), num_inputs)

    @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
    def testDatasetComplexSourceInputs(self):
        """`from_sparse_tensor_slices` is a source with no dataset inputs."""
        dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
            sparse_tensor.SparseTensor(indices=np.array([[0, 0], [1, 0],
                                                         [2, 0]]),
                                       values=np.array([0, 0, 0]),
                                       dense_shape=np.array([3, 1])))
        self.assertEmpty(dataset_fn._inputs())

    @parameterized.named_parameters(
        ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
        ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
        ("Filter", lambda x: x.filter(lambda x: True),
         lambda: dataset_ops.Dataset.range(0)),
        ("FlatMap",
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
         lambda: dataset_ops.Dataset.range(0)),
        ("Map", lambda x: x.map(lambda x: x),
         lambda: dataset_ops.Dataset.range(0)),
        ("PaddedBatch", lambda x: x.padded_batch(10, []),
         lambda: dataset_ops.Dataset.range(0)),
        ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
         lambda: dataset_ops.Dataset.range(0)),
        ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
        ("Shuffle", lambda x: x.shuffle(10),
         lambda: dataset_ops.Dataset.range(0)),
        ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
        ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
        ("Window", lambda x: x.window(10),
         lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
        """A unary transformation reports exactly its input dataset."""
        input_dataset = input_dataset_fn()
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    def testUnaryTransformationInputsApply(self):
        """`Dataset.apply(cache)` reports exactly its input dataset."""
        input_dataset = dataset_ops.Dataset.range(0)
        dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2
                                ], lambda: dataset_ops.Dataset.range(0)),
        ("Interleave", [lambda: dataset_ops.Dataset.range(0), None
                        ], lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputsWithInterleaveFn(self, interleave_fn_args,
                                                      input_dataset_fn):
        """(Parallel) interleave reports exactly its input dataset."""
        input_dataset = input_dataset_fn()
        # interleave_fn_args is (unused dataset arg, num_parallel_calls).
        dataset_fn = self.make_interleave_fn(*interleave_fn_args)
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    def testNoWarnings(self):
        """Building an interleave pipeline emits no `warnings.warn` calls."""
        with test.mock.patch.object(warnings, "warn") as mock_log:
            dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
            dataset_fn(dataset_ops.Dataset.range(10))
            self.assertEmpty(mock_log.call_args_list)

    @parameterized.named_parameters(
        ("Concatenate", lambda x, y: x.concatenate(y),
         lambda: dataset_ops.Dataset.range(0),
         lambda: dataset_ops.Dataset.range(1)))
    def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
        """A binary transformation reports both inputs, in order."""
        input1 = input1_fn()
        input2 = input2_fn()
        self.assertEqual([input1, input2],
                         dataset_fn(input1, input2)._inputs())

    @parameterized.named_parameters(
        ("ZipOne", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0))),
        ("ZipNest", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0),
          (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
        ("ZipTuple", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
    )
    def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
        """`zip` reports the flattened nest of its input datasets."""
        input_datasets = input_datasets_fn()
        self.assertEqual(nest.flatten(input_datasets),
                         dataset_fn(input_datasets)._inputs())

    def testFunctions(self):
        """A single `map` contributes exactly one traced function."""
        dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
        self.assertLen(dataset._functions(), 1)

    def testCollectInputs(self):
        """A breadth-first walk of `_inputs()` visits shared inputs repeatedly."""
        ds1 = dataset_ops.Dataset.range(0)
        ds2 = ds1.concatenate(ds1)
        ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

        inputs = []
        queue = [ds3]
        while queue:
            ds = queue[0]
            queue = queue[1:]
            queue.extend(ds._inputs())
            inputs.append(ds)

        # ds1 is reached once directly from ds3 and twice via each ds2.
        self.assertEqual(5, inputs.count(ds1))
        self.assertEqual(2, inputs.count(ds2))
        self.assertEqual(1, inputs.count(ds3))

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         tensor_spec.TensorSpec([], dtypes.float32)),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), sparse_tensor.SparseTensorSpec([1],
                                                             dtypes.int32)),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        }, {
            "a":
            tensor_spec.TensorSpec([], dtypes.float32),
            "b": (
                tensor_spec.TensorSpec([1], dtypes.string),
                tensor_spec.TensorSpec([], dtypes.string),
            )
        }),
        ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3])),
         dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32))),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalSpec(tensor_spec.TensorSpec([],
                                                          dtypes.float32))),
    )
    def testDatasetSpec(self, tf_value_fn, expected_element_structure):
        """`DatasetSpec` of a one-element dataset: structure and round-trip.

        Args:
          tf_value_fn: Zero-argument factory for the single element value.
          expected_element_structure: Spec the element is expected to match.
        """
        dataset = dataset_ops.Dataset.from_tensors(0).map(
            lambda _: tf_value_fn())
        dataset_structure = structure.type_spec_from_value(dataset)
        self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)

        self.assertTrue(
            structure.are_compatible(dataset_ops.get_structure(dataset),
                                     expected_element_structure))
        self.assertEqual([dtypes.variant],
                         structure.get_flat_tensor_types(dataset_structure))
        # `TensorShape([])` is the scalar shape; `tensor_shape.scalar()` is
        # deprecated.
        self.assertEqual([tensor_shape.TensorShape([])],
                         structure.get_flat_tensor_shapes(dataset_structure))

        # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_dataset = dataset_structure._from_tensor_list(
            dataset_structure._to_tensor_list(dataset))

        value = tf_value_fn()

        if isinstance(value, dataset_ops.Dataset):
            self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
        elif isinstance(value, optional_ops.Optional):
            self.assertDatasetProduces(
                round_trip_dataset.map(lambda opt: opt.get_value()),
                [self.evaluate(value.get_value())],
                requires_initialization=True)
        else:
            self.assertDatasetProduces(round_trip_dataset,
                                       [self.evaluate(tf_value_fn())],
                                       requires_initialization=True)

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorOneShot(self):
        """Transforming a dataset from another graph raises ValueError."""
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                dataset = dataset.batch(2)

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorOneShotSimple(self):
        """A one-shot iterator over a foreign-graph dataset logs a warning."""
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with test.mock.patch.object(logging, "warning") as mock_log:
                _ = dataset_ops.make_one_shot_iterator(dataset)
                self.assertRegexpMatches(
                    str(mock_log.call_args),
                    "Please ensure that all datasets in the "
                    "pipeline are created in the same graph as the iterator.")

    # NOTE(review): this body is identical to testSkipEagerSameGraphErrorOneShot
    # and never builds an initializable iterator; confirm whether that is
    # intended.
    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorInitializable(self):
        """Transforming a dataset from another graph raises ValueError."""
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                dataset = dataset.batch(2)

    @parameterized.named_parameters(
        ("Async", context.ASYNC),
        ("Sync", context.SYNC),
    )
    def testDatasetEagerIteration(self, execution_mode):
        """Eager Python iteration yields 0..9 under both execution modes."""
        with context.eager_mode(), context.execution_mode(execution_mode):
            val = 0
            dataset = dataset_ops.Dataset.range(10)
            for foo in dataset:
                self.assertEqual(val, foo.numpy())
                val += 1

    def testDatasetAsFunctionArgument(self):
        """Datasets passed to `tf.function` are inputs, traced by structure."""
        @def_function.function
        def _uses_dataset(d):
            accumulator = array_ops.zeros([], dtype=dtypes.int64)
            for value in d:
                accumulator += value
            return accumulator

        with ops.device("CPU"):
            first_dataset = dataset_ops.Dataset.range(10)
            self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
            second_dataset = dataset_ops.Dataset.range(11)
            self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
            first_concrete = _uses_dataset.get_concrete_function(first_dataset)
            # The dataset should not be a captured input
            self.assertEmpty(first_concrete.graph.captures)
            # The two datasets have the same structure and so should re-use a trace.
            self.assertIs(first_concrete,
                          _uses_dataset.get_concrete_function(second_dataset))
            # With a different structure we should use a different trace.
            self.assertIsNot(
                first_concrete,
                _uses_dataset.get_concrete_function(
                    dataset_ops.Dataset.zip((first_dataset, second_dataset))))

    def testLimitedRetracing(self):
        """Same-structure datasets share a single `tf.function` trace."""
        trace_count = [0]

        @def_function.function
        def f(ds):
            trace_count[0] += 1
            counter = np.int64(0)
            for elem in ds:
                counter += elem
            return counter

        dataset = dataset_ops.Dataset.range(5)
        dataset2 = dataset_ops.Dataset.range(10)

        for _ in range(10):
            self.assertEqual(self.evaluate(f(dataset)), 10)
            self.assertEqual(self.evaluate(f(dataset2)), 45)
            self.assertEqual(trace_count[0], 1)
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `optional_ops.Optional` values, specs and variant-tensor ops."""

  @staticmethod
  def _available_devices():
    """Returns the devices the device-parameterized tests run on.

    Always contains "/cpu:0"; "/gpu:0" is appended when
    `test_util.is_gpu_available()` reports a GPU.
    """
    devices = ["/cpu:0"]
    if test_util.is_gpu_available():
      devices.append("/gpu:0")
    return devices

  def testFromValue(self):
    """An Optional built from a tensor has that tensor as its value."""
    opt = optional_ops.Optional.from_value(constant_op.constant(37.0))
    self.assertTrue(self.evaluate(opt.has_value()))
    self.assertEqual(37.0, self.evaluate(opt.get_value()))

  def testFromStructuredValue(self):
    """An Optional can hold a nested dict/tuple structure of tensors."""
    opt = optional_ops.Optional.from_value({
        "a": constant_op.constant(37.0),
        "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
    })
    self.assertTrue(self.evaluate(opt.has_value()))
    self.assertEqual({
        "a": 37.0,
        "b": ([b"Foo"], b"Bar")
    }, self.evaluate(opt.get_value()))

  def testFromSparseTensor(self):
    """An Optional can hold a tuple of sparse tensor values."""
    st_0 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0]]),
        values=np.array([0], dtype=np.int64),
        dense_shape=np.array([1]))
    st_1 = sparse_tensor.SparseTensorValue(
        indices=np.array([[0, 0], [1, 1]]),
        values=np.array([-1., 1.], dtype=np.float32),
        dense_shape=np.array([2, 2]))
    opt = optional_ops.Optional.from_value((st_0, st_1))
    self.assertTrue(self.evaluate(opt.has_value()))
    val_0, val_1 = opt.get_value()
    for expected, actual in [(st_0, val_0), (st_1, val_1)]:
      self.assertAllEqual(expected.indices, self.evaluate(actual.indices))
      self.assertAllEqual(expected.values, self.evaluate(actual.values))
      self.assertAllEqual(expected.dense_shape,
                          self.evaluate(actual.dense_shape))

  def testFromNone(self):
    """An empty Optional keeps its structure and fails on `get_value`."""
    value_structure = tensor_spec.TensorSpec([], dtypes.float32)
    opt = optional_ops.Optional.none_from_structure(value_structure)
    self.assertTrue(opt.value_structure.is_compatible_with(value_structure))
    self.assertFalse(
        opt.value_structure.is_compatible_with(
            tensor_spec.TensorSpec([1], dtypes.float32)))
    self.assertFalse(
        opt.value_structure.is_compatible_with(
            tensor_spec.TensorSpec([], dtypes.int32)))
    self.assertFalse(self.evaluate(opt.has_value()))
    with self.assertRaises(errors.InvalidArgumentError):
      self.evaluate(opt.get_value())

  def testAddN(self):
    """`add_n` on Optional variants adds values; two empties stay empty."""
    for device in self._available_devices():
      with ops.device(device):
        # With value
        opt1 = optional_ops.Optional.from_value((1.0, 2.0))
        opt2 = optional_ops.Optional.from_value((3.0, 4.0))

        add_tensor = math_ops.add_n([opt1._variant_tensor,
                                     opt2._variant_tensor])
        add_opt = optional_ops._OptionalImpl(add_tensor, opt1.value_structure)
        self.assertAllEqual(self.evaluate(add_opt.get_value()), (4.0, 6.0))

        # Without value
        opt_none1 = optional_ops.Optional.none_from_structure(
            opt1.value_structure)
        opt_none2 = optional_ops.Optional.none_from_structure(
            opt2.value_structure)
        add_tensor = math_ops.add_n([opt_none1._variant_tensor,
                                     opt_none2._variant_tensor])
        add_opt = optional_ops._OptionalImpl(add_tensor,
                                             opt_none1.value_structure)
        self.assertFalse(self.evaluate(add_opt.has_value()))

  def testNestedAddN(self):
    """`add_n` also adds Optionals nested inside an Optional's value."""
    for device in self._available_devices():
      with ops.device(device):
        opt1 = optional_ops.Optional.from_value([1, 2.0])
        opt2 = optional_ops.Optional.from_value([3, 4.0])
        opt3 = optional_ops.Optional.from_value((5.0, opt1._variant_tensor))
        opt4 = optional_ops.Optional.from_value((6.0, opt2._variant_tensor))

        add_tensor = math_ops.add_n([opt3._variant_tensor,
                                     opt4._variant_tensor])
        add_opt = optional_ops._OptionalImpl(add_tensor, opt3.value_structure)
        self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0)

        inner_add_opt = optional_ops._OptionalImpl(add_opt.get_value()[1],
                                                   opt1.value_structure)
        # Evaluate before comparing so the check also works in graph mode,
        # where `get_value()` returns symbolic tensors.
        self.assertAllEqual(self.evaluate(inner_add_opt.get_value()), [4, 6.0])

  def testZerosLike(self):
    """`zeros_like` zeroes an Optional's values; an empty one stays empty."""
    for device in self._available_devices():
      with ops.device(device):
        # With value
        opt = optional_ops.Optional.from_value((1.0, 2.0))
        zeros_tensor = array_ops.zeros_like(opt._variant_tensor)
        zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
                                               opt.value_structure)
        self.assertAllEqual(self.evaluate(zeros_opt.get_value()),
                            (0.0, 0.0))

        # Without value
        opt_none = optional_ops.Optional.none_from_structure(
            opt.value_structure)
        zeros_tensor = array_ops.zeros_like(opt_none._variant_tensor)
        zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
                                               opt_none.value_structure)
        self.assertFalse(self.evaluate(zeros_opt.has_value()))

  def testNestedZerosLike(self):
    """`zeros_like` also zeroes an Optional nested inside an Optional."""
    for device in self._available_devices():
      with ops.device(device):
        opt1 = optional_ops.Optional.from_value(1.0)
        opt2 = optional_ops.Optional.from_value(opt1._variant_tensor)

        zeros_tensor = array_ops.zeros_like(opt2._variant_tensor)
        zeros_opt = optional_ops._OptionalImpl(zeros_tensor,
                                               opt2.value_structure)
        inner_zeros_opt = optional_ops._OptionalImpl(zeros_opt.get_value(),
                                                     opt1.value_structure)
        self.assertEqual(self.evaluate(inner_zeros_opt.get_value()), 0.0)

  def testCopyToGPU(self):
    """Optional variants survive an `identity` copy from CPU to GPU."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/cpu:0"):
      optional_with_value = optional_ops.Optional.from_value(
          (constant_op.constant(37.0), constant_op.constant("Foo"),
           constant_op.constant(42)))
      optional_none = optional_ops.Optional.none_from_structure(
          tensor_spec.TensorSpec([], dtypes.float32))

    with ops.device("/gpu:0"):
      gpu_optional_with_value = optional_ops._OptionalImpl(
          array_ops.identity(optional_with_value._variant_tensor),
          optional_with_value.value_structure)
      gpu_optional_none = optional_ops._OptionalImpl(
          array_ops.identity(optional_none._variant_tensor),
          optional_none.value_structure)

      gpu_optional_with_value_has_value = gpu_optional_with_value.has_value()
      gpu_optional_with_value_values = gpu_optional_with_value.get_value()

      gpu_optional_none_has_value = gpu_optional_none.has_value()

    self.assertTrue(self.evaluate(gpu_optional_with_value_has_value))
    self.assertEqual((37.0, b"Foo", 42),
                     self.evaluate(gpu_optional_with_value_values))
    self.assertFalse(self.evaluate(gpu_optional_none_has_value))

  def testNestedCopyToGPU(self):
    """Nested Optional variants survive an `identity` copy to GPU."""
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")

    with ops.device("/cpu:0"):
      optional_with_value = optional_ops.Optional.from_value(
          (constant_op.constant(37.0), constant_op.constant("Foo"),
           constant_op.constant(42)))
      optional_none = optional_ops.Optional.none_from_structure(
          tensor_spec.TensorSpec([], dtypes.float32))
      nested_optional = optional_ops.Optional.from_value(
          (optional_with_value._variant_tensor, optional_none._variant_tensor,
           1.0))

    with ops.device("/gpu:0"):
      gpu_nested_optional = optional_ops._OptionalImpl(
          array_ops.identity(nested_optional._variant_tensor),
          nested_optional.value_structure)

      gpu_nested_optional_has_value = gpu_nested_optional.has_value()
      gpu_nested_optional_values = gpu_nested_optional.get_value()

    self.assertTrue(self.evaluate(gpu_nested_optional_has_value))

    inner_with_value = optional_ops._OptionalImpl(
        gpu_nested_optional_values[0], optional_with_value.value_structure)

    inner_none = optional_ops._OptionalImpl(
        gpu_nested_optional_values[1], optional_none.value_structure)

    self.assertEqual((37.0, b"Foo", 42),
                     self.evaluate(inner_with_value.get_value()))
    self.assertFalse(self.evaluate(inner_none.has_value()))
    self.assertEqual(1.0, self.evaluate(gpu_nested_optional_values[2]))

  def _assertElementValueEqual(self, expected, actual):
    """Recursively compares dicts, `SparseTensorValue`s and array-likes.

    Args:
      expected: The expected value; its type selects the comparison used.
      actual: The value to check against `expected`.
    """
    if isinstance(expected, dict):
      self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
      for k in expected.keys():
        self._assertElementValueEqual(expected[k], actual[k])
    elif isinstance(expected, sparse_tensor.SparseTensorValue):
      self.assertAllEqual(expected.indices, actual.indices)
      self.assertAllEqual(expected.values, actual.values)
      self.assertAllEqual(expected.dense_shape, actual.dense_shape)
    else:
      self.assertAllEqual(expected, actual)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
      ("Tensor", lambda: constant_op.constant(37.0),
       tensor_spec.TensorSpec([], dtypes.float32)),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0, 1]],
          values=constant_op.constant([0], dtype=dtypes.int32),
          dense_shape=[10, 10]),
       sparse_tensor.SparseTensorSpec([10, 10], dtypes.int32)),
      ("Nest", lambda: {
          "a": constant_op.constant(37.0),
          "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
      }, {
          "a":
              tensor_spec.TensorSpec([], dtypes.float32),
          "b": (
              tensor_spec.TensorSpec([1], dtypes.string),
              tensor_spec.TensorSpec([], dtypes.string),
          )
      }),
      ("Optional", lambda: optional_ops.Optional.from_value(37.0),
       optional_ops.OptionalSpec(
           tensor_spec.TensorSpec([], dtypes.float32))),
  )
  def testOptionalSpec(self, tf_value_fn, expected_value_structure):
    """`OptionalSpec` structure, compatibility and tensor-list round-trip.

    Args:
      tf_value_fn: Zero-argument factory for the value to wrap.
      expected_value_structure: Spec the wrapped value is expected to match.
    """
    tf_value = tf_value_fn()
    opt = optional_ops.Optional.from_value(tf_value)

    self.assertTrue(
        structure.are_compatible(opt.value_structure, expected_value_structure))

    opt_structure = structure.type_spec_from_value(opt)
    self.assertIsInstance(opt_structure, optional_ops.OptionalSpec)
    self.assertTrue(structure.are_compatible(opt_structure, opt_structure))
    self.assertTrue(
        structure.are_compatible(opt_structure._value_structure,
                                 expected_value_structure))
    self.assertEqual([dtypes.variant],
                     structure.get_flat_tensor_types(opt_structure))
    self.assertEqual([tensor_shape.TensorShape([])],
                     structure.get_flat_tensor_shapes(opt_structure))

    # No OptionalSpec is compatible with a non-optional value.
    non_optional_structure = structure.type_spec_from_value(
        constant_op.constant(42.0))
    self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))

    # Assert that the optional survives a round-trip via _from_tensor_list()
    # and _to_tensor_list().
    round_trip_opt = opt_structure._from_tensor_list(
        opt_structure._to_tensor_list(opt))
    if isinstance(tf_value, optional_ops.Optional):
      self._assertElementValueEqual(
          self.evaluate(tf_value.get_value()),
          self.evaluate(round_trip_opt.get_value().get_value()))
    else:
      self._assertElementValueEqual(
          self.evaluate(tf_value),
          self.evaluate(round_trip_opt.get_value()))

  @parameterized.named_parameters(
      ("Tensor", np.array([1, 2, 3], dtype=np.int32),
       lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
      ("SparseTensor", sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]],
          values=np.array([-1., 1.], dtype=np.float32), dense_shape=[2, 2]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[0, 1], [1, 0]], values=[37.0, 42.0], dense_shape=[2, 2]),
       False),
      ("Nest", {"a": np.array([1, 2, 3], dtype=np.int32),
                "b": sparse_tensor.SparseTensorValue(
                    indices=[[0, 0], [1, 1]],
                    values=np.array([-1., 1.], dtype=np.float32),
                    dense_shape=[2, 2])},
       lambda: {"a": constant_op.constant([4, 5, 6], dtype=dtypes.int32),
                "b": sparse_tensor.SparseTensor(
                    indices=[[0, 1], [1, 0]], values=[37.0, 42.0],
                    dense_shape=[2, 2])}, False),
  )
  def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                    works_on_gpu):
    """`get_next_as_optional` yields 3 full Optionals, then empty ones.

    Args:
      np_value: Element value; the dataset repeats it 3 times.
      tf_value_fn: Factory for a value whose spec the Optional must match.
      works_on_gpu: If False, the case is skipped when a GPU is available.
    """
    if not works_on_gpu and test_util.is_gpu_available():
      self.skipTest("Test case not yet supported on GPU.")
    ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)

    if context.executing_eagerly():
      iterator = dataset_ops.make_one_shot_iterator(ds)
      # For each element of the dataset, assert that the optional evaluates to
      # the expected value.
      for _ in range(3):
        next_elem = iterator_ops.get_next_as_optional(iterator)
        self.assertIsInstance(next_elem, optional_ops.Optional)
        self.assertTrue(structure.are_compatible(
            next_elem.value_structure,
            structure.type_spec_from_value(tf_value_fn())))
        self.assertTrue(self.evaluate(next_elem.has_value()))
        self._assertElementValueEqual(np_value,
                                      self.evaluate(next_elem.get_value()))
      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
      for _ in range(2):
        next_elem = iterator_ops.get_next_as_optional(iterator)
        self.assertFalse(self.evaluate(next_elem.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(next_elem.get_value())
    else:
      iterator = dataset_ops.make_initializable_iterator(ds)
      next_elem = iterator_ops.get_next_as_optional(iterator)
      self.assertIsInstance(next_elem, optional_ops.Optional)
      self.assertTrue(structure.are_compatible(
          next_elem.value_structure,
          structure.type_spec_from_value(tf_value_fn())))
      # Before initializing the iterator, evaluating the optional fails with
      # a FailedPreconditionError. This is only relevant in graph mode.
      elem_has_value_t = next_elem.has_value()
      elem_value_t = next_elem.get_value()
      with self.assertRaises(errors.FailedPreconditionError):
        self.evaluate(elem_has_value_t)
      with self.assertRaises(errors.FailedPreconditionError):
        self.evaluate(elem_value_t)
      # Now we initialize the iterator.
      self.evaluate(iterator.initializer)
      # For each element of the dataset, assert that the optional evaluates to
      # the expected value.
      for _ in range(3):
        elem_has_value, elem_value = self.evaluate(
            [elem_has_value_t, elem_value_t])
        self.assertTrue(elem_has_value)
        self._assertElementValueEqual(np_value, elem_value)

      # After exhausting the iterator, `next_elem.has_value()` will evaluate to
      # false, and attempting to get the value will fail.
      for _ in range(2):
        self.assertFalse(self.evaluate(elem_has_value_t))
        with self.assertRaises(errors.InvalidArgumentError):
          self.evaluate(elem_value_t)

  def testFunctionBoundaries(self):
    """An Optional's variant tensor can be passed between tf.functions."""
    @def_function.function
    def get_optional():
      x = constant_op.constant(1.0)
      opt = optional_ops.Optional.from_value(x)
      # TODO(skyewm): support returning Optionals from functions?
      return opt._variant_tensor

    # TODO(skyewm): support Optional arguments?
    @def_function.function
    def consume_optional(opt_tensor):
      value_structure = tensor_spec.TensorSpec([], dtypes.float32)
      opt = optional_ops._OptionalImpl(opt_tensor, value_structure)
      return opt.get_value()

    opt_tensor = get_optional()
    val = consume_optional(opt_tensor)
    self.assertEqual(self.evaluate(val), 1.0)

  def testLimitedRetracing(self):
    """Same-structure Optionals share a single `tf.function` trace."""
    trace_count = [0]

    @def_function.function
    def f(opt):
      trace_count[0] += 1
      return opt.get_value()

    opt1 = optional_ops.Optional.from_value(constant_op.constant(37.0))
    opt2 = optional_ops.Optional.from_value(constant_op.constant(42.0))

    for _ in range(10):
      self.assertEqual(self.evaluate(f(opt1)), 37.0)
      self.assertEqual(self.evaluate(f(opt2)), 42.0)
      self.assertEqual(trace_count[0], 1)