# Example no. 1
# 0
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for `Dataset` serialization, input tracking and structure."""

    def testAsSerializedGraph(self):
        """The serialized graph of `range(10)` contains a RangeDataset op."""
        dataset = dataset_ops.Dataset.range(10)
        graph = graph_pb2.GraphDef().FromString(
            self.evaluate(dataset._as_serialized_graph()))
        # Checking `node.op != "RangeDataset"` passes as soon as the graph
        # holds any other node, so it could never detect a missing range op.
        self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

    @staticmethod
    def make_apply_fn(dataset):
        """Returns a function that applies `cache()` through `Dataset.apply`.

        Args:
          dataset: Unused; kept so existing callers keep working.
        """
        def apply_fn(dataset):
            def _apply_fn(dataset):
                return dataset.cache()

            return dataset.apply(_apply_fn)

        return apply_fn

    @staticmethod
    def make_gen():
        """Returns a generator function that yields the single value 42."""
        def gen():
            yield 42

        return gen

    @staticmethod
    def make_interleave_fn(dataset, num_parallel_calls=None):
        """Returns a function that interleaves empty `range(0)` datasets.

        Args:
          dataset: Unused; kept so existing callers keep working.
          num_parallel_calls: Forwarded to `Dataset.interleave`. The test
            cases below pass `None` ("Interleave") and 2
            ("ParallelInterleave").
        """
        def interleave_fn(dataset):
            return dataset.interleave(lambda x: dataset_ops.Dataset.range(0),
                                      cycle_length=2,
                                      num_parallel_calls=num_parallel_calls)

        return interleave_fn

    @parameterized.named_parameters(
        ("FixedLengthRecord",
         lambda: readers.FixedLengthRecordDataset("", 42)),
        ("FromGenerator", lambda: dataset_ops.Dataset.from_generator(
            DatasetTest.make_gen(), dtypes.int32), 1),
        ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
        # Previously this case built `from_tensors`, so `from_tensor_slices`
        # was never exercised despite the case name.
        ("FromTensorSlices",
         lambda: dataset_ops.Dataset.from_tensor_slices([42])),
        ("Range", lambda: dataset_ops.Dataset.range(10)),
        ("TextLine", lambda: readers.TextLineDataset("")),
        ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
    )
    def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
        """Source datasets report the expected number of `_inputs()`."""
        self.assertEqual(num_inputs, len(dataset_fn()._inputs()))

    def testDatasetComplexSourceInputs(self):
        """`from_sparse_tensor_slices` is a source with no dataset inputs."""
        dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
            sparse_tensor.SparseTensor(indices=np.array([[0, 0], [1, 0],
                                                         [2, 0]]),
                                       values=np.array([0, 0, 0]),
                                       dense_shape=np.array([3, 1])))
        self.assertEqual(0, len(dataset_fn._inputs()))

    @parameterized.named_parameters(
        ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
        ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
        ("Filter", lambda x: x.filter(lambda x: True),
         lambda: dataset_ops.Dataset.range(0)),
        ("FlatMap",
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
         lambda: dataset_ops.Dataset.range(0)),
        ("Map", lambda x: x.map(lambda x: x),
         lambda: dataset_ops.Dataset.range(0)),
        ("PaddedBatch", lambda x: x.padded_batch(10, []),
         lambda: dataset_ops.Dataset.range(0)),
        ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
         lambda: dataset_ops.Dataset.range(0)),
        ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
        ("Shuffle", lambda x: x.shuffle(10),
         lambda: dataset_ops.Dataset.range(0)),
        ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
        ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
        ("Window", lambda x: x.window(10),
         lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
        """A unary transformation's only input is the dataset it wraps."""
        input_dataset = input_dataset_fn()
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    def testUnaryTransformationInputsApply(self):
        """`Dataset.apply` keeps the applied-to dataset as its input."""
        input_dataset = dataset_ops.Dataset.range(0)
        # The argument of `make_apply_fn` is ignored, so reuse `input_dataset`
        # rather than building a throwaway dataset.
        dataset_fn = self.make_apply_fn(input_dataset)
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2
                                ], lambda: dataset_ops.Dataset.range(0)),
        ("Interleave", [lambda: dataset_ops.Dataset.range(0), None
                        ], lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputsWithInterleaveFn(self, interleave_fn_args,
                                                      input_dataset_fn):
        """(Parallel) interleave reports only the outer dataset as input."""
        input_dataset = input_dataset_fn()
        dataset_fn = self.make_interleave_fn(*interleave_fn_args)
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("Concatenate", lambda x, y: x.concatenate(y),
         lambda: dataset_ops.Dataset.range(0),
         lambda: dataset_ops.Dataset.range(1)))
    def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
        """A binary transformation reports both inputs, in order."""
        input1 = input1_fn()
        input2 = input2_fn()
        self.assertEqual([input1, input2],
                         dataset_fn(input1, input2)._inputs())

    @parameterized.named_parameters(
        ("ZipOne", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0))),
        ("ZipNest", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0),
          (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
        ("ZipTuple", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
    )
    def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
        """`zip` reports its (flattened) nested inputs in order."""
        input_datasets = input_datasets_fn()
        self.assertEqual(nest.flatten(input_datasets),
                         dataset_fn(input_datasets)._inputs())

    def testCollectInputs(self):
        """A breadth-first walk of `_inputs()` revisits shared inputs."""
        import collections  # Local: only this test needs a deque.

        ds1 = dataset_ops.Dataset.range(0)
        ds2 = ds1.concatenate(ds1)
        ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

        # deque.popleft() is O(1); the previous `queue = queue[1:]` copied
        # the remaining list on every step.
        inputs = []
        queue = collections.deque([ds3])
        while queue:
            ds = queue.popleft()
            queue.extend(ds._inputs())
            inputs.append(ds)

        self.assertEqual(5, inputs.count(ds1))
        self.assertEqual(2, inputs.count(ds2))
        self.assertEqual(1, inputs.count(ds3))

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), structure.SparseTensorStructure(
                dtypes.int32, [1])),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.TensorStructure(dtypes.string, [1]),
                   structure.TensorStructure(dtypes.string, []))
         })),
        ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3])),
         dataset_ops.DatasetStructure(
             structure.TensorStructure(dtypes.int32, []))),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalStructure(
             structure.TensorStructure(dtypes.float32, []))),
    )
    def testDatasetStructure(self, tf_value_fn, expected_element_structure):
        """A dataset of `tf_value_fn()` elements has the expected structure.

        Args:
          tf_value_fn: Zero-argument callable building one element value.
          expected_element_structure: Structure the dataset's elements should
            be mutually compatible with.
        """
        dataset = dataset_ops.Dataset.from_tensors(0).map(
            lambda _: tf_value_fn())
        dataset_structure = structure.Structure.from_value(dataset)
        self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

        # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
        # the element structure.
        self.assertTrue(
            expected_element_structure.is_compatible_with(
                dataset_structure._element_structure))
        self.assertTrue(
            dataset_structure._element_structure.is_compatible_with(
                expected_element_structure))

        self.assertEqual([dtypes.variant], dataset_structure._flat_types)
        self.assertEqual([tensor_shape.scalar()],
                         dataset_structure._flat_shapes)

        # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_dataset = dataset_structure._from_tensor_list(
            dataset_structure._to_tensor_list(dataset))

        value = tf_value_fn()

        if isinstance(value, dataset_ops.Dataset):
            # NOTE(review): this branch compares against `dataset`, not
            # `round_trip_dataset`, so the round trip is not exercised for
            # nested datasets; presumably intentional because this helper
            # does not request an initializable iterator. Confirm.
            self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
        elif isinstance(value, optional_ops.Optional):
            self.assertDatasetProduces(
                round_trip_dataset.map(lambda opt: opt.get_value()),
                [self.evaluate(value.get_value())],
                requires_initialization=True)
        else:
            # Reuse `value` rather than building the element a second time.
            self.assertDatasetProduces(round_trip_dataset,
                                       [self.evaluate(value)],
                                       requires_initialization=True)

    @test_util.run_deprecated_v1
    def testSkipEagerSameGraphErrorOneShot(self):
        """Transforming a dataset from another graph raises ValueError."""
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                dataset = dataset.batch(2)

    @test_util.run_deprecated_v1
    def testSkipEagerSameGraphErrorOneShotSimple(self):
        """A one-shot iterator over a foreign-graph dataset logs a warning."""
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with test.mock.patch.object(logging, "warning") as mock_log:
                _ = dataset.make_one_shot_iterator()
                self.assertRegexpMatches(
                    str(mock_log.call_args),
                    "Please ensure that all datasets in the "
                    "pipeline are created in the same graph as the iterator.")

    @test_util.run_deprecated_v1
    def testSkipEagerSameGraphErrorInitializable(self):
        """Transforming a dataset from another graph raises ValueError."""
        # NOTE(review): body is identical to testSkipEagerSameGraphErrorOneShot
        # and creates no initializable iterator despite the name.
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                dataset = dataset.batch(2)
# Example no. 2
# 0
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for building, inspecting and transferring `Optional` values."""

    @test_util.run_in_graph_and_eager_modes
    def testFromValue(self):
        """An Optional built from a scalar tensor reports and returns it."""
        optional = optional_ops.Optional.from_value(constant_op.constant(37.0))
        has_value = self.evaluate(optional.has_value())
        self.assertTrue(has_value)
        self.assertEqual(37.0, self.evaluate(optional.get_value()))

    @test_util.run_in_graph_and_eager_modes
    def testFromStructuredValue(self):
        """An Optional can wrap a nested dict/tuple of tensors."""
        element = {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")),
        }
        optional = optional_ops.Optional.from_value(element)
        self.assertTrue(self.evaluate(optional.has_value()))
        expected = {"a": 37.0, "b": ([b"Foo"], b"Bar")}
        self.assertEqual(expected, self.evaluate(optional.get_value()))

    @test_util.run_in_graph_and_eager_modes
    def testFromSparseTensor(self):
        """A tuple of sparse values survives wrapping in an Optional."""
        sparse_first = sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0], dtype=np.int64),
            dense_shape=np.array([1]))
        sparse_second = sparse_tensor.SparseTensorValue(
            indices=np.array([[0, 0], [1, 1]]),
            values=np.array([-1., 1.], dtype=np.float32),
            dense_shape=np.array([2, 2]))
        optional = optional_ops.Optional.from_value(
            (sparse_first, sparse_second))
        self.assertTrue(self.evaluate(optional.has_value()))
        first_actual, second_actual = optional.get_value()
        pairs = ((sparse_first, first_actual), (sparse_second, second_actual))
        for expected, actual in pairs:
            # Compare component-wise: indices, then values, then dense_shape.
            for field in ("indices", "values", "dense_shape"):
                self.assertAllEqual(getattr(expected, field),
                                    self.evaluate(getattr(actual, field)))

    @test_util.run_in_graph_and_eager_modes
    def testFromNone(self):
        """An empty Optional keeps its value structure but holds nothing."""
        scalar_float = structure.TensorStructure(dtypes.float32, [])
        optional = optional_ops.Optional.none_from_structure(scalar_float)
        self.assertTrue(
            optional.value_structure.is_compatible_with(scalar_float))
        mismatched_structures = (
            structure.TensorStructure(dtypes.float32, [1]),
            structure.TensorStructure(dtypes.int32, []),
        )
        for mismatched in mismatched_structures:
            self.assertFalse(
                optional.value_structure.is_compatible_with(mismatched))
        self.assertFalse(self.evaluate(optional.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(optional.get_value())

    @test_util.run_in_graph_and_eager_modes
    def testCopyToGPU(self):
        """Optional variants built on CPU can be copied to and read on GPU."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        with ops.device("/cpu:0"):
            cpu_with_value = optional_ops.Optional.from_value(
                (constant_op.constant(37.0), constant_op.constant("Foo"),
                 constant_op.constant(42)))
            cpu_none = optional_ops.Optional.none_from_structure(
                structure.TensorStructure(dtypes.float32, []))

        with ops.device("/gpu:0"):
            # Copy each variant with `identity` under the GPU device scope,
            # then rewrap it with its original value structure.
            gpu_with_value, gpu_none = [
                optional_ops._OptionalImpl(
                    array_ops.identity(source._variant_tensor),
                    source.value_structure)
                for source in (cpu_with_value, cpu_none)
            ]
            with_value_has_value = gpu_with_value.has_value()
            with_value_values = gpu_with_value.get_value()
            none_has_value = gpu_none.has_value()

        self.assertTrue(self.evaluate(with_value_has_value))
        self.assertEqual((37.0, b"Foo", 42), self.evaluate(with_value_values))
        self.assertFalse(self.evaluate(none_has_value))

    def _assertElementValueEqual(self, expected, actual):
        """Recursively asserts `actual` matches `expected`.

        Dicts are compared key-by-key, `SparseTensorValue`s component-wise,
        and anything else with `assertAllEqual`.
        """
        if isinstance(expected, dict):
            self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
            for key in expected:
                self._assertElementValueEqual(expected[key], actual[key])
            return
        if isinstance(expected, sparse_tensor.SparseTensorValue):
            for field in ("indices", "values", "dense_shape"):
                self.assertAllEqual(getattr(expected, field),
                                    getattr(actual, field))
            return
        self.assertAllEqual(expected, actual)

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), structure.SparseTensorStructure(
                dtypes.int32, [1])),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.TensorStructure(dtypes.string, [1]),
                   structure.TensorStructure(dtypes.string, []))
         })),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalStructure(
             structure.TensorStructure(dtypes.float32, []))),
    )
    def testOptionalStructure(self, tf_value_fn, expected_value_structure):
        """Checks an Optional's structures and its tensor-list round trip.

        Args:
          tf_value_fn: Zero-argument callable building the wrapped value.
          expected_value_structure: Structure the wrapped value should be
            mutually compatible with.
        """
        tf_value = tf_value_fn()
        opt = optional_ops.Optional.from_value(tf_value)

        # Compatibility is checked in both directions.
        directions = ((expected_value_structure, opt.value_structure),
                      (opt.value_structure, expected_value_structure))
        for lhs, rhs in directions:
            self.assertTrue(lhs.is_compatible_with(rhs))

        opt_structure = structure.Structure.from_value(opt)
        self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
        self.assertTrue(opt_structure.is_compatible_with(opt_structure))
        self.assertTrue(
            opt_structure._value_structure.is_compatible_with(
                expected_value_structure))
        self.assertEqual([dtypes.variant], opt_structure._flat_types)
        self.assertEqual([tensor_shape.scalar()], opt_structure._flat_shapes)

        # No OptionalStructure is compatible with a plain (non-optional) value.
        plain_structure = structure.Structure.from_value(
            constant_op.constant(42.0))
        self.assertFalse(opt_structure.is_compatible_with(plain_structure))

        # The optional must survive _to_tensor_list() -> _from_tensor_list().
        round_trip_opt = opt_structure._from_tensor_list(
            opt_structure._to_tensor_list(opt))
        if isinstance(tf_value, optional_ops.Optional):
            expected = self.evaluate(tf_value.get_value())
            actual = self.evaluate(round_trip_opt.get_value().get_value())
        else:
            expected = self.evaluate(tf_value)
            actual = self.evaluate(round_trip_opt.get_value())
        self.assertEqual(expected, actual)

    @parameterized.named_parameters(
        ("Tensor", np.array([1, 2, 3], dtype=np.int32),
         lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
        ("SparseTensor",
         sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                         values=np.array([-1., 1.],
                                                         dtype=np.float32),
                                         dense_shape=[2, 2]),
         lambda: sparse_tensor.SparseTensor(indices=[[0, 1], [1, 0]],
                                            values=[37.0, 42.0],
                                            dense_shape=[2, 2]), False),
        ("Nest", {
            "a":
            np.array([1, 2, 3], dtype=np.int32),
            "b":
            sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                            values=np.array([-1., 1.],
                                                            dtype=np.float32),
                                            dense_shape=[2, 2])
        }, lambda: {
            "a":
            constant_op.constant([4, 5, 6], dtype=dtypes.int32),
            "b":
            sparse_tensor.SparseTensor(indices=[[0, 1], [1, 0]],
                                       values=[37.0, 42.0],
                                       dense_shape=[2, 2])
        }, False),
    )
    def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                      works_on_gpu):
        """`get_next_as_optional` yields each element, then empty optionals.

        Args:
          np_value: Element repeated 3 times by the dataset under test.
          tf_value_fn: Builds a value whose structure the optional must match.
          works_on_gpu: If False, the case is skipped when a GPU is present.
        """
        if not works_on_gpu and test.is_gpu_available():
            self.skipTest("Test case not yet supported on GPU.")
        num_elements = 3
        dataset = dataset_ops.Dataset.from_tensors(np_value).repeat(
            num_elements)
        iterator = dataset.make_initializable_iterator()
        next_elem = iterator_ops.get_next_as_optional(iterator)
        self.assertIsInstance(next_elem, optional_ops.Optional)
        value_structure = next_elem.value_structure
        self.assertTrue(
            value_structure.is_compatible_with(
                structure.Structure.from_value(tf_value_fn())))
        has_value_t = next_elem.has_value()
        value_t = next_elem.get_value()
        with self.cached_session() as sess:
            # Reading either tensor before initialization must fail.
            for fetch in (has_value_t, value_t):
                with self.assertRaises(errors.FailedPreconditionError):
                    sess.run(fetch)

            # Each dataset element arrives as a present optional.
            self.evaluate(iterator.initializer)
            for _ in range(num_elements):
                present, element = sess.run([has_value_t, value_t])
                self.assertTrue(present)
                self._assertElementValueEqual(np_value, element)

            # Once exhausted, `has_value()` is False and reading the value
            # fails, no matter how often it is retried.
            for _ in range(2):
                self.assertFalse(self.evaluate(has_value_t))
                with self.assertRaises(errors.InvalidArgumentError):
                    sess.run(value_t)
# Example no. 3
# 0
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
    """Tests for `Dataset` serialization, tracing, inputs and structure."""

    def testAsSerializedGraph(self):
        """The serialized graph of `range(10)` contains a RangeDataset op."""
        dataset = dataset_ops.Dataset.range(10)
        graph = graph_pb2.GraphDef().FromString(
            self.evaluate(dataset._as_serialized_graph()))
        # Checking `node.op != "RangeDataset"` passes as soon as the graph
        # holds any other node, so it could never detect a missing range op.
        self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

    def testAsFunctionWithMap(self):
        """A traced variant of a mapped dataset can be revived (eager only)."""
        if not context.executing_eagerly():
            self.skipTest("Only works executing eagerly")
        original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
        fn = original_dataset._trace_variant_creation()
        variant = fn()

        revived_dataset = _RevivedDataset(variant,
                                          original_dataset._element_structure)
        self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

    def testAsFunctionWithMapInFlatMap(self):
        """Tracing also works when a map is nested inside flat_map."""
        if not context.executing_eagerly():
            self.skipTest("Only works executing eagerly")
        original_dataset = dataset_ops.Dataset.range(5).flat_map(
            lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
        fn = original_dataset._trace_variant_creation()
        variant = fn()

        revived_dataset = _RevivedDataset(variant,
                                          original_dataset._element_structure)
        self.assertDatasetProduces(revived_dataset, list(original_dataset))

    @staticmethod
    def make_apply_fn(dataset):
        """Returns a function that applies `cache()` through `Dataset.apply`.

        Args:
          dataset: Unused; kept so existing callers keep working.
        """
        def apply_fn(dataset):
            def _apply_fn(dataset):
                return dataset.cache()

            return dataset.apply(_apply_fn)

        return apply_fn

    @staticmethod
    def make_gen():
        """Returns a generator function that yields the single value 42."""
        def gen():
            yield 42

        return gen

    @staticmethod
    def make_interleave_fn(dataset, num_parallel_calls=None):
        """Returns a function that interleaves empty `range(0)` datasets.

        Args:
          dataset: Unused; kept so existing callers keep working.
          num_parallel_calls: Forwarded to `Dataset.interleave`. The test
            cases below pass `None` ("Interleave") and 2
            ("ParallelInterleave").
        """
        def interleave_fn(dataset):
            return dataset.interleave(lambda x: dataset_ops.Dataset.range(0),
                                      cycle_length=2,
                                      num_parallel_calls=num_parallel_calls)

        return interleave_fn

    @parameterized.named_parameters(
        ("FixedLengthRecord",
         lambda: readers.FixedLengthRecordDataset("", 42)),
        ("FromGenerator", lambda: dataset_ops.Dataset.from_generator(
            DatasetTest.make_gen(), dtypes.int32), 1),
        ("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
        # Previously this case built `from_tensors`, so `from_tensor_slices`
        # was never exercised despite the case name.
        ("FromTensorSlices",
         lambda: dataset_ops.Dataset.from_tensor_slices([42])),
        ("Range", lambda: dataset_ops.Dataset.range(10)),
        ("TextLine", lambda: readers.TextLineDataset("")),
        ("TFRecord", lambda: readers.TFRecordDataset(""), 1),
    )
    def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
        """Source datasets report the expected number of `_inputs()`."""
        self.assertLen(dataset_fn()._inputs(), num_inputs)

    @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
    def testDatasetComplexSourceInputs(self):
        """`from_sparse_tensor_slices` is a source with no dataset inputs."""
        dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
            sparse_tensor.SparseTensor(indices=np.array([[0, 0], [1, 0],
                                                         [2, 0]]),
                                       values=np.array([0, 0, 0]),
                                       dense_shape=np.array([3, 1])))
        self.assertEmpty(dataset_fn._inputs())

    @parameterized.named_parameters(
        ("Batch", lambda x: x.batch(10), lambda: dataset_ops.Dataset.range(0)),
        ("Cache", lambda x: x.cache(), lambda: dataset_ops.Dataset.range(0)),
        ("Filter", lambda x: x.filter(lambda x: True),
         lambda: dataset_ops.Dataset.range(0)),
        ("FlatMap",
         lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
         lambda: dataset_ops.Dataset.range(0)),
        ("Map", lambda x: x.map(lambda x: x),
         lambda: dataset_ops.Dataset.range(0)),
        ("PaddedBatch", lambda x: x.padded_batch(10, []),
         lambda: dataset_ops.Dataset.range(0)),
        ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
         lambda: dataset_ops.Dataset.range(0)),
        ("Repeat", lambda x: x.repeat(), lambda: dataset_ops.Dataset.range(0)),
        ("Shuffle", lambda x: x.shuffle(10),
         lambda: dataset_ops.Dataset.range(0)),
        ("Skip", lambda x: x.skip(1), lambda: dataset_ops.Dataset.range(0)),
        ("Take", lambda x: x.take(1), lambda: dataset_ops.Dataset.range(0)),
        ("Window", lambda x: x.window(10),
         lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
        """A unary transformation's only input is the dataset it wraps."""
        input_dataset = input_dataset_fn()
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    def testUnaryTransformationInputsApply(self):
        """`Dataset.apply` keeps the applied-to dataset as its input."""
        input_dataset = dataset_ops.Dataset.range(0)
        # The argument of `make_apply_fn` is ignored, so reuse `input_dataset`
        # rather than building a throwaway dataset.
        dataset_fn = self.make_apply_fn(input_dataset)
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("ParallelInterleave", [lambda: dataset_ops.Dataset.range(0), 2
                                ], lambda: dataset_ops.Dataset.range(0)),
        ("Interleave", [lambda: dataset_ops.Dataset.range(0), None
                        ], lambda: dataset_ops.Dataset.range(0)),
    )
    def testUnaryTransformationInputsWithInterleaveFn(self, interleave_fn_args,
                                                      input_dataset_fn):
        """(Parallel) interleave reports only the outer dataset as input."""
        input_dataset = input_dataset_fn()
        dataset_fn = self.make_interleave_fn(*interleave_fn_args)
        self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

    @parameterized.named_parameters(
        ("Concatenate", lambda x, y: x.concatenate(y),
         lambda: dataset_ops.Dataset.range(0),
         lambda: dataset_ops.Dataset.range(1)))
    def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
        """A binary transformation reports both inputs, in order."""
        input1 = input1_fn()
        input2 = input2_fn()
        self.assertEqual([input1, input2],
                         dataset_fn(input1, input2)._inputs())

    @parameterized.named_parameters(
        ("ZipOne", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0))),
        ("ZipNest", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0),
          (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
        ("ZipTuple", dataset_ops.Dataset.zip, lambda:
         (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))),
    )
    def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
        """`zip` reports its (flattened) nested inputs in order."""
        input_datasets = input_datasets_fn()
        self.assertEqual(nest.flatten(input_datasets),
                         dataset_fn(input_datasets)._inputs())

    def testFunctions(self):
        """A single `map` contributes exactly one traced function."""
        dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
        self.assertLen(dataset._functions(), 1)

    def testCollectInputs(self):
        """A breadth-first walk of `_inputs()` revisits shared inputs."""
        import collections  # Local: only this test needs a deque.

        ds1 = dataset_ops.Dataset.range(0)
        ds2 = ds1.concatenate(ds1)
        ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))

        # deque.popleft() is O(1); the previous `queue = queue[1:]` copied
        # the remaining list on every step.
        inputs = []
        queue = collections.deque([ds3])
        while queue:
            ds = queue.popleft()
            queue.extend(ds._inputs())
            inputs.append(ds)

        self.assertEqual(5, inputs.count(ds1))
        self.assertEqual(2, inputs.count(ds2))
        self.assertEqual(1, inputs.count(ds3))

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[1]), structure.SparseTensorStructure(
                dtypes.int32, [1])),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        },
         structure.NestedStructure({
             "a":
             structure.TensorStructure(dtypes.float32, []),
             "b": (structure.TensorStructure(dtypes.string, [1]),
                   structure.TensorStructure(dtypes.string, []))
         })),
        ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
            constant_op.constant([1, 2, 3])),
         dataset_ops.DatasetStructure(
             structure.TensorStructure(dtypes.int32, []))),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalStructure(
             structure.TensorStructure(dtypes.float32, []))),
    )
    def testDatasetStructure(self, tf_value_fn, expected_element_structure):
        """A dataset of `tf_value_fn()` elements has the expected structure.

        Args:
          tf_value_fn: Zero-argument callable building one element value.
          expected_element_structure: Structure the dataset's elements should
            be mutually compatible with.
        """
        dataset = dataset_ops.Dataset.from_tensors(0).map(
            lambda _: tf_value_fn())
        dataset_structure = structure.Structure.from_value(dataset)
        self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

        # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
        # the element structure.
        self.assertTrue(
            expected_element_structure.is_compatible_with(
                dataset_structure._element_structure))
        self.assertTrue(
            dataset_structure._element_structure.is_compatible_with(
                expected_element_structure))

        self.assertEqual([dtypes.variant], dataset_structure._flat_types)
        self.assertEqual([tensor_shape.scalar()],
                         dataset_structure._flat_shapes)

        # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_dataset = dataset_structure._from_tensor_list(
            dataset_structure._to_tensor_list(dataset))

        value = tf_value_fn()

        if isinstance(value, dataset_ops.Dataset):
            # NOTE(review): this branch compares against `dataset`, not
            # `round_trip_dataset`, so the round trip is not exercised for
            # nested datasets; presumably intentional because this helper
            # does not request an initializable iterator. Confirm.
            self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
        elif isinstance(value, optional_ops.Optional):
            self.assertDatasetProduces(
                round_trip_dataset.map(lambda opt: opt.get_value()),
                [self.evaluate(value.get_value())],
                requires_initialization=True)
        else:
            # Reuse `value` rather than building the element a second time.
            self.assertDatasetProduces(round_trip_dataset,
                                       [self.evaluate(value)],
                                       requires_initialization=True)

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorOneShot(self):
        """Transforming a dataset from another graph raises ValueError."""
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                dataset = dataset.batch(2)

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorOneShotSimple(self):
        """A one-shot iterator over a foreign-graph dataset logs a warning."""
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with test.mock.patch.object(logging, "warning") as mock_log:
                _ = dataset_ops.make_one_shot_iterator(dataset)
                self.assertRegexpMatches(
                    str(mock_log.call_args),
                    "Please ensure that all datasets in the "
                    "pipeline are created in the same graph as the iterator.")

    @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
    def testSkipEagerSameGraphErrorInitializable(self):
        """Transforming a dataset from another graph raises ValueError."""
        # NOTE(review): body is identical to testSkipEagerSameGraphErrorOneShot
        # and creates no initializable iterator despite the name.
        dataset = dataset_ops.Dataset.range(10)
        with ops.Graph().as_default():
            with self.assertRaisesRegexp(ValueError,
                                         "must be from the same graph"):
                dataset = dataset.batch(2)

    @parameterized.named_parameters(
        ("Async", context.ASYNC),
        ("Sync", context.SYNC),
    )
    def testDatasetEagerIteration(self, execution_mode):
        """Eager iteration yields 0..9 in order in both execution modes."""
        with context.eager_mode(), context.execution_mode(execution_mode):
            dataset = dataset_ops.Dataset.range(10)
            for expected, foo in enumerate(dataset):
                self.assertEqual(expected, foo.numpy())

    def testDatasetAsFunctionArgument(self):
        """Datasets can be passed as arguments to a `tf.function`."""
        @def_function.function
        def _uses_dataset(d):
            accumulator = array_ops.zeros([], dtype=dtypes.int64)
            for value in d:
                accumulator += value
            return accumulator

        first_dataset = dataset_ops.Dataset.range(10)
        self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
        second_dataset = dataset_ops.Dataset.range(11)
        self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
        first_concrete = _uses_dataset.get_concrete_function(first_dataset)
        # `skipTest` raises, so the assertions below never run; they document
        # the intended behavior for when datasets become first-class
        # function arguments.
        self.skipTest((
            "Not currently working: functions treat Datasets as opaque Python "
            "objects"))
        # The dataset should not be a captured input
        self.assertEmpty(first_concrete.graph.captures)
        # The two datasets have the same structure and so should re-use a trace.
        self.assertIs(first_concrete,
                      _uses_dataset.get_concrete_function(second_dataset))
        # With a different structure we should use a different trace.
        self.assertIsNot(
            first_concrete,
            _uses_dataset.get_concrete_function(
                dataset_ops.Dataset.zip((first_dataset, second_dataset))))
# Example no. 4
# 0
class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
    def testFromValue(self):
        """An Optional built from a scalar constant holds that value."""
        payload = constant_op.constant(37.0)
        opt = optional_ops.Optional.from_value(payload)
        has_value = self.evaluate(opt.has_value())
        self.assertTrue(has_value)
        self.assertEqual(37.0, self.evaluate(opt.get_value()))

    def testFromStructuredValue(self):
        """An Optional preserves a nested dict/tuple structure of tensors."""
        structured = {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar")),
        }
        expected = {"a": 37.0, "b": ([b"Foo"], b"Bar")}
        opt = optional_ops.Optional.from_value(structured)
        self.assertTrue(self.evaluate(opt.has_value()))
        self.assertEqual(expected, self.evaluate(opt.get_value()))

    def testFromSparseTensor(self):
        """An Optional can wrap a pair of SparseTensorValues losslessly."""
        sparse_a = sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0], dtype=np.int64),
            dense_shape=np.array([1]))
        sparse_b = sparse_tensor.SparseTensorValue(
            indices=np.array([[0, 0], [1, 1]]),
            values=np.array([-1., 1.], dtype=np.float32),
            dense_shape=np.array([2, 2]))
        opt = optional_ops.Optional.from_value((sparse_a, sparse_b))
        self.assertTrue(self.evaluate(opt.has_value()))
        got_a, got_b = opt.get_value()
        # Compare every component of each sparse value, in a fixed order.
        for expected, actual in ((sparse_a, got_a), (sparse_b, got_b)):
            for field in ("indices", "values", "dense_shape"):
                self.assertAllEqual(getattr(expected, field),
                                    self.evaluate(getattr(actual, field)))

    def testFromNone(self):
        """An empty Optional keeps its structure but yields no value."""
        scalar_float = structure.TensorStructure(dtypes.float32, [])
        opt = optional_ops.Optional.none_from_structure(scalar_float)
        self.assertTrue(opt.value_structure.is_compatible_with(scalar_float))
        # A different shape or a different dtype must not be compatible.
        incompatible_structures = (
            structure.TensorStructure(dtypes.float32, [1]),
            structure.TensorStructure(dtypes.int32, []),
        )
        for other in incompatible_structures:
            self.assertFalse(opt.value_structure.is_compatible_with(other))
        self.assertFalse(self.evaluate(opt.has_value()))
        with self.assertRaises(errors.InvalidArgumentError):
            self.evaluate(opt.get_value())

    def testAddN(self):
        """`add_n` over Optional variants adds values; empty stays empty.

        Runs on the CPU and, when one is available, also on the GPU.
        """
        devices = ["/cpu:0"]
        if test_util.is_gpu_available():
            devices += ["/gpu:0"]
        for device_name in devices:
            with ops.device(device_name):
                # Case 1: both optionals carry a value.
                lhs = optional_ops.Optional.from_value((1.0, 2.0))
                rhs = optional_ops.Optional.from_value((3.0, 4.0))
                summed = optional_ops._OptionalImpl(
                    math_ops.add_n([lhs._variant_tensor, rhs._variant_tensor]),
                    lhs.value_structure)
                self.assertAllEqual(self.evaluate(summed.get_value()),
                                    (4.0, 6.0))

                # Case 2: both optionals are empty.
                empty_lhs = optional_ops.Optional.none_from_structure(
                    lhs.value_structure)
                empty_rhs = optional_ops.Optional.none_from_structure(
                    rhs.value_structure)
                summed_empty = optional_ops._OptionalImpl(
                    math_ops.add_n(
                        [empty_lhs._variant_tensor, empty_rhs._variant_tensor]),
                    empty_lhs.value_structure)
                self.assertFalse(self.evaluate(summed_empty.has_value()))

    def testNestedAddN(self):
        """`add_n` recurses into Optionals nested inside Optionals.

        Each outer Optional holds `(float, inner_optional_variant)`. Adding
        the outer variants is expected to add the floats and also the inner
        Optionals' values. Runs on the CPU and, when available, the GPU.
        """
        devices = ["/cpu:0"]
        if test_util.is_gpu_available():
            devices.append("/gpu:0")
        for device in devices:
            with ops.device(device):
                opt1 = optional_ops.Optional.from_value([1, 2.0])
                opt2 = optional_ops.Optional.from_value([3, 4.0])
                opt3 = optional_ops.Optional.from_value(
                    (5.0, opt1._variant_tensor))
                opt4 = optional_ops.Optional.from_value(
                    (6.0, opt2._variant_tensor))

                add_tensor = math_ops.add_n(
                    [opt3._variant_tensor, opt4._variant_tensor])
                add_opt = optional_ops._OptionalImpl(add_tensor,
                                                     opt3.value_structure)
                self.assertEqual(self.evaluate(add_opt.get_value()[0]), 11.0)

                inner_add_opt = optional_ops._OptionalImpl(
                    add_opt.get_value()[1], opt1.value_structure)
                # Evaluate explicitly, like the outer assertion above, rather
                # than handing an unevaluated tensor to assertAllEqual.
                self.assertAllEqual(self.evaluate(inner_add_opt.get_value()),
                                    [4, 6.0])

    def testZerosLike(self):
        """`zeros_like` on Optional variants zeroes values; empty stays empty."""
        devices = ["/cpu:0"]
        if test_util.is_gpu_available():
            devices += ["/gpu:0"]
        for device_name in devices:
            with ops.device(device_name):
                # Case 1: the optional carries a value.
                full = optional_ops.Optional.from_value((1.0, 2.0))
                zeroed_full = optional_ops._OptionalImpl(
                    array_ops.zeros_like(full._variant_tensor),
                    full.value_structure)
                self.assertAllEqual(self.evaluate(zeroed_full.get_value()),
                                    (0.0, 0.0))

                # Case 2: the optional is empty.
                empty = optional_ops.Optional.none_from_structure(
                    full.value_structure)
                zeroed_empty = optional_ops._OptionalImpl(
                    array_ops.zeros_like(empty._variant_tensor),
                    empty.value_structure)
                self.assertFalse(self.evaluate(zeroed_empty.has_value()))

    def testNestedZerosLike(self):
        """`zeros_like` on a nested Optional also zeroes the inner value."""
        devices = ["/cpu:0"]
        if test_util.is_gpu_available():
            devices += ["/gpu:0"]
        for device_name in devices:
            with ops.device(device_name):
                inner = optional_ops.Optional.from_value(1.0)
                outer = optional_ops.Optional.from_value(inner._variant_tensor)

                zeroed_outer = optional_ops._OptionalImpl(
                    array_ops.zeros_like(outer._variant_tensor),
                    outer.value_structure)
                zeroed_inner = optional_ops._OptionalImpl(
                    zeroed_outer.get_value(), inner.value_structure)
                self.assertEqual(self.evaluate(zeroed_inner.get_value()), 0.0)

    def testCopyToGPU(self):
        """Optionals built on the CPU can be copied to and read on the GPU."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        with ops.device("/cpu:0"):
            cpu_full = optional_ops.Optional.from_value(
                (constant_op.constant(37.0), constant_op.constant("Foo"),
                 constant_op.constant(42)))
            cpu_empty = optional_ops.Optional.none_from_structure(
                structure.TensorStructure(dtypes.float32, []))

        with ops.device("/gpu:0"):
            # An `identity` placed on the GPU is used to copy each variant.
            gpu_full = optional_ops._OptionalImpl(
                array_ops.identity(cpu_full._variant_tensor),
                cpu_full.value_structure)
            gpu_empty = optional_ops._OptionalImpl(
                array_ops.identity(cpu_empty._variant_tensor),
                cpu_empty.value_structure)

            full_has_value = gpu_full.has_value()
            full_values = gpu_full.get_value()
            empty_has_value = gpu_empty.has_value()

        self.assertTrue(self.evaluate(full_has_value))
        self.assertEqual((37.0, b"Foo", 42), self.evaluate(full_values))
        self.assertFalse(self.evaluate(empty_has_value))

    def testNestedCopyToGPU(self):
        """Nested Optionals survive copying the outer variant to the GPU."""
        if not test_util.is_gpu_available():
            self.skipTest("No GPU available")

        with ops.device("/cpu:0"):
            full = optional_ops.Optional.from_value(
                (constant_op.constant(37.0), constant_op.constant("Foo"),
                 constant_op.constant(42)))
            empty = optional_ops.Optional.none_from_structure(
                structure.TensorStructure(dtypes.float32, []))
            outer = optional_ops.Optional.from_value(
                (full._variant_tensor, empty._variant_tensor, 1.0))

        with ops.device("/gpu:0"):
            gpu_outer = optional_ops._OptionalImpl(
                array_ops.identity(outer._variant_tensor),
                outer.value_structure)
            outer_has_value = gpu_outer.has_value()
            outer_values = gpu_outer.get_value()

        self.assertTrue(self.evaluate(outer_has_value))

        # Re-wrap the two inner variants with their original structures.
        inner_full = optional_ops._OptionalImpl(outer_values[0],
                                                full.value_structure)
        inner_empty = optional_ops._OptionalImpl(outer_values[1],
                                                 empty.value_structure)

        self.assertEqual((37.0, b"Foo", 42),
                         self.evaluate(inner_full.get_value()))
        self.assertFalse(self.evaluate(inner_empty.has_value()))
        self.assertEqual(1.0, self.evaluate(outer_values[2]))

    def _assertElementValueEqual(self, expected, actual):
        """Recursively assert that `actual` matches `expected`.

        Dicts are compared key by key (the key sets must match, ignoring
        order). `SparseTensorValue`s are compared field by field. Anything
        else is compared with `assertAllEqual`.
        """
        if isinstance(expected, dict):
            self.assertItemsEqual(list(expected.keys()), list(actual.keys()))
            for key, expected_item in expected.items():
                self._assertElementValueEqual(expected_item, actual[key])
            return
        if isinstance(expected, sparse_tensor.SparseTensorValue):
            for field in ("indices", "values", "dense_shape"):
                self.assertAllEqual(getattr(expected, field),
                                    getattr(actual, field))
            return
        self.assertAllEqual(expected, actual)

    # pylint: disable=g-long-lambda
    @parameterized.named_parameters(
        ("Tensor", lambda: constant_op.constant(37.0),
         structure.TensorStructure(dtypes.float32, [])),
        ("SparseTensor", lambda: sparse_tensor.SparseTensor(
            indices=[[0, 1]],
            values=constant_op.constant([0], dtype=dtypes.int32),
            dense_shape=[10, 10]),
         structure.SparseTensorStructure(dtypes.int32, [10, 10])),
        ("Nest", lambda: {
            "a": constant_op.constant(37.0),
            "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
        }, {
            "a":
            structure.TensorStructure(dtypes.float32, []),
            "b": (structure.TensorStructure(dtypes.string, [1]),
                  structure.TensorStructure(dtypes.string, []))
        }),
        ("Optional", lambda: optional_ops.Optional.from_value(37.0),
         optional_ops.OptionalStructure(
             structure.TensorStructure(dtypes.float32, []))),
    )
    def testOptionalStructure(self, tf_value_fn, expected_value_structure):
        """Checks `OptionalStructure` inference and tensor-list round-trips.

        Args:
          tf_value_fn: zero-argument callable producing the value to wrap.
          expected_value_structure: structure the wrapped value is expected
            to be compatible with.
        """
        tf_value = tf_value_fn()
        opt = optional_ops.Optional.from_value(tf_value)

        self.assertTrue(
            structure.are_compatible(opt.value_structure,
                                     expected_value_structure))

        opt_structure = structure.type_spec_from_value(opt)
        self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
        self.assertTrue(structure.are_compatible(opt_structure, opt_structure))
        self.assertTrue(
            structure.are_compatible(opt_structure._value_structure,
                                     expected_value_structure))
        # The optional itself flattens to a single scalar variant tensor.
        self.assertEqual([dtypes.variant], opt_structure._flat_types)
        # `tensor_shape.scalar()` is deprecated (and removed in later TF);
        # `TensorShape([])` is the equivalent scalar shape.
        self.assertEqual([tensor_shape.TensorShape([])],
                         opt_structure._flat_shapes)

        # All OptionalStructure objects are not compatible with a non-optional
        # value.
        non_optional_structure = structure.type_spec_from_value(
            constant_op.constant(42.0))
        self.assertFalse(
            opt_structure.is_compatible_with(non_optional_structure))

        # Assert that the optional survives a round-trip via _from_tensor_list()
        # and _to_tensor_list().
        round_trip_opt = opt_structure._from_tensor_list(
            opt_structure._to_tensor_list(opt))
        if isinstance(tf_value, optional_ops.Optional):
            self._assertElementValueEqual(
                self.evaluate(tf_value.get_value()),
                self.evaluate(round_trip_opt.get_value().get_value()))
        else:
            self._assertElementValueEqual(
                self.evaluate(tf_value),
                self.evaluate(round_trip_opt.get_value()))

    @parameterized.named_parameters(
        ("Tensor", np.array([1, 2, 3], dtype=np.int32),
         lambda: constant_op.constant([4, 5, 6], dtype=dtypes.int32), True),
        ("SparseTensor",
         sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                         values=np.array([-1., 1.],
                                                         dtype=np.float32),
                                         dense_shape=[2, 2]),
         lambda: sparse_tensor.SparseTensor(indices=[[0, 1], [1, 0]],
                                            values=[37.0, 42.0],
                                            dense_shape=[2, 2]), False),
        ("Nest", {
            "a":
            np.array([1, 2, 3], dtype=np.int32),
            "b":
            sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 1]],
                                            values=np.array([-1., 1.],
                                                            dtype=np.float32),
                                            dense_shape=[2, 2])
        }, lambda: {
            "a":
            constant_op.constant([4, 5, 6], dtype=dtypes.int32),
            "b":
            sparse_tensor.SparseTensor(indices=[[0, 1], [1, 0]],
                                       values=[37.0, 42.0],
                                       dense_shape=[2, 2])
        }, False),
    )
    def testIteratorGetNextAsOptional(self, np_value, tf_value_fn,
                                      works_on_gpu):
        """`get_next_as_optional` yields filled, then empty, Optionals.

        The dataset repeats `np_value` three times. The first three Optionals
        must hold it, and the following ones must be empty.

        Args:
          np_value: element value the dataset is built from.
          tf_value_fn: zero-argument callable whose result's structure the
            Optional's `value_structure` must be compatible with.
          works_on_gpu: when False, the case is skipped if a GPU is present.
        """
        # Use `test_util` like every other GPU check in this class.
        if not works_on_gpu and test_util.is_gpu_available():
            self.skipTest("Test case not yet supported on GPU.")
        ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)

        if context.executing_eagerly():
            iterator = dataset_ops.make_one_shot_iterator(ds)
            # For each element of the dataset, assert that the optional evaluates to
            # the expected value.
            for _ in range(3):
                next_elem = iterator_ops.get_next_as_optional(iterator)
                self.assertIsInstance(next_elem, optional_ops.Optional)
                self.assertTrue(
                    structure.are_compatible(
                        next_elem.value_structure,
                        structure.type_spec_from_value(tf_value_fn())))
                # Evaluate explicitly, as the exhausted-iterator checks do.
                self.assertTrue(self.evaluate(next_elem.has_value()))
                self._assertElementValueEqual(np_value, next_elem.get_value())
            # After exhausting the iterator, `next_elem.has_value()` will evaluate to
            # false, and attempting to get the value will fail.
            for _ in range(2):
                next_elem = iterator_ops.get_next_as_optional(iterator)
                self.assertFalse(self.evaluate(next_elem.has_value()))
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(next_elem.get_value())
        else:
            iterator = dataset_ops.make_initializable_iterator(ds)
            next_elem = iterator_ops.get_next_as_optional(iterator)
            self.assertIsInstance(next_elem, optional_ops.Optional)
            self.assertTrue(
                structure.are_compatible(
                    next_elem.value_structure,
                    structure.type_spec_from_value(tf_value_fn())))
            # Before initializing the iterator, evaluating the optional fails with
            # a FailedPreconditionError. This is only relevant in graph mode.
            elem_has_value_t = next_elem.has_value()
            elem_value_t = next_elem.get_value()
            with self.assertRaises(errors.FailedPreconditionError):
                self.evaluate(elem_has_value_t)
            with self.assertRaises(errors.FailedPreconditionError):
                self.evaluate(elem_value_t)
            # Now we initialize the iterator.
            self.evaluate(iterator.initializer)
            # For each element of the dataset, assert that the optional evaluates to
            # the expected value.
            for _ in range(3):
                elem_has_value, elem_value = self.evaluate(
                    [elem_has_value_t, elem_value_t])
                self.assertTrue(elem_has_value)
                self._assertElementValueEqual(np_value, elem_value)

            # After exhausting the iterator, `next_elem.has_value()` will evaluate to
            # false, and attempting to get the value will fail.
            for _ in range(2):
                self.assertFalse(self.evaluate(elem_has_value_t))
                with self.assertRaises(errors.InvalidArgumentError):
                    self.evaluate(elem_value_t)

    def testFunctionBoundaries(self):
        """Optional variants can be returned from and passed into functions."""

        # TODO(skyewm): support returning Optionals from functions?
        @def_function.function
        def get_optional():
            opt = optional_ops.Optional.from_value(constant_op.constant(1.0))
            return opt._variant_tensor

        # TODO(skyewm): support Optional arguments?
        @def_function.function
        def consume_optional(opt_tensor):
            scalar_float = structure.TensorStructure(dtypes.float32, [])
            return optional_ops._OptionalImpl(opt_tensor,
                                              scalar_float).get_value()

        result = consume_optional(get_optional())
        self.assertEqual(self.evaluate(result), 1.0)

    def testLimitedRetracing(self):
        """Optionals with the same structure share a single function trace."""
        # A one-element list so the nested function can mutate the count.
        num_traces = [0]

        @def_function.function
        def f(opt):
            num_traces[0] += 1
            return opt.get_value()

        opt_a = optional_ops.Optional.from_value(constant_op.constant(37.0))
        opt_b = optional_ops.Optional.from_value(constant_op.constant(42.0))
        cases = ((opt_a, 37.0), (opt_b, 42.0))

        for _ in range(10):
            for opt, expected in cases:
                self.assertEqual(self.evaluate(f(opt)), expected)
            # The test expects exactly one trace across all calls so far.
            self.assertEqual(num_traces[0], 1)