Example #1
0
 def graphRoundTrip(self, dataset, allow_stateful=False):
     """Converts a dataset to a serialized graph and back.

     Args:
       dataset: The dataset to round-trip. Its variant tensor is serialized
         with `dataset_to_graph`, and its `element_spec` is reused for the
         rebuilt dataset.
       allow_stateful: Forwarded unchanged as the `allow_stateful` argument of
         `dataset_to_graph`.

     Returns:
       The dataset rebuilt via `dataset_from_graph` from the serialized graph.
     """
     # Keep the protected access on its own line so the line-scoped pylint
     # disable actually covers the line that triggers the warning.
     variant = dataset._variant_tensor  # pylint: disable=protected-access
     graph = gen_dataset_ops.dataset_to_graph(
         variant, allow_stateful=allow_stateful)
     return dataset_ops.from_variant(
         gen_experimental_dataset_ops.dataset_from_graph(graph),
         dataset.element_spec)
Example #2
0
  def testSerialization(self):
    """Checks that a dataset mapped with the tokenizer survives serialization.

    The dataset graph is serialized in one eager context. The context is then
    reset, and the dataset is rebuilt from the serialized graph alone. Every
    restored element must match the tokenizer's direct output, and the number
    of restored elements must match the number of inputs.
    """
    with context.eager_mode():
      sentencepiece_model_file = (
          'tensorflow_text/python/ops/test_data/'
          'test_oss_model.model')
      # Close the model file deterministically instead of leaking the handle
      # until garbage collection.
      with gfile.GFile(sentencepiece_model_file, 'rb') as model_file:
        model = model_file.read()
      sp = SentencepieceTokenizer(model)
      strings = ['hello', 'world']
      dataset = dataset_ops.Dataset.from_tensor_slices(strings)
      # Ensure we can map the tokenizer across the dataset.
      dataset = dataset.map(sp.tokenize)
      graph = dataset._as_serialized_graph()
      element_spec = dataset.element_spec
      dataset_graph_string = graph.numpy()
      expected = sp.tokenize(strings)

    # Reset the eager context to make sure that the serialized dataset graph
    # is self-contained.
    context._reset_context()

    with context.eager_mode():
      restored = dataset_ops.from_variant(
          gen_experimental_dataset_ops.dataset_from_graph(dataset_graph_string),
          element_spec)
      num_restored = 0
      for i, result in enumerate(restored):
        self.assertAllEqual(result, expected[i])
        num_restored += 1
      # Without this, a restored dataset yielding too few (or zero) elements
      # would pass the loop above silently.
      self.assertEqual(num_restored, len(strings))
Example #3
0
    def testImportedFunctionsRegistered(self):
        """Re-imports a GraphDef containing a dataset `reduce` and runs it.

        A graph is built that reduces a dataset received through a variant
        placeholder. Its GraphDef is then imported inside `wrap_function`
        and evaluated on the variant of a concrete dataset.
        """
        if test.is_built_with_gpu_support():
            self.skipTest(
                "Disabling this new test due to errors with cuda and rocm")

        with ops.Graph().as_default() as source_graph:
            variant_in = array_ops.placeholder(
                dtypes.variant, shape=[], name='foo')
            element_structure = structure.TensorStructure(dtypes.int32, [])
            placeholder_ds = dataset_ops.from_variant(
                variant_in, structure=element_structure)
            initial = array_ops.zeros([], dtype=dtypes.int32)
            total = placeholder_ds.reduce(initial,
                                          lambda acc, elem: acc + elem)

        serialized = source_graph.as_graph_def()

        def import_and_reduce(variant):
            # Feed `variant` in place of the placeholder and fetch the sum.
            outputs = graph_def_importer.import_graph_def(
                serialized,
                input_map={variant_in.name: variant},
                return_elements=[total.name])
            return outputs[0]

        variant_spec = tensor_spec.TensorSpec((), dtypes.variant)
        wrapped = wrap_function.wrap_function(import_and_reduce,
                                              [variant_spec])
        concrete_ds = dataset_ops.Dataset.from_tensor_slices([10, 20])
        self.evaluate(wrapped(dataset_ops.to_variant(concrete_ds)))
Example #4
0
 def testRoundtripMap(self):
     """Round-trips a mapped range dataset through its variant tensor."""
     squares = dataset_ops.Dataset.range(10).map(lambda v: v * v)
     restored = dataset_ops.from_variant(
         dataset_ops.to_variant(squares), dataset_ops.get_structure(squares))
     self.assertDatasetProduces(restored, [n * n for n in range(10)])
     # The element count must survive the round trip as well.
     restored_cardinality = self.evaluate(cardinality.cardinality(restored))
     self.assertEqual(restored_cardinality, 10)
Example #5
0
 def testRoundtripRange(self):
     """Round-trips a plain range dataset through its variant tensor."""
     source = dataset_ops.Dataset.range(10)
     variant_tensor = dataset_ops.to_variant(source)
     element_structure = dataset_ops.get_structure(source)
     restored = dataset_ops.from_variant(variant_tensor, element_structure)
     self.assertDatasetProduces(restored, range(10))
     restored_cardinality = self.evaluate(cardinality.cardinality(restored))
     self.assertEqual(restored_cardinality, 10)
Example #6
0
 def testRoundtripMap(self):
   """Round-trips a mapped range dataset through its variant tensor."""
   mapped = dataset_ops.Dataset.range(10).map(lambda elem: elem * elem)
   as_variant = dataset_ops.to_variant(mapped)
   restored = dataset_ops.from_variant(as_variant,
                                       dataset_ops.get_structure(mapped))
   expected = [k * k for k in range(10)]
   self.assertDatasetProduces(restored, expected)
   self.assertEqual(self.evaluate(cardinality.cardinality(restored)), 10)
Example #7
0
 def testRoundtripRange(self):
   """Round-trips a plain range dataset through its variant tensor."""
   original = dataset_ops.Dataset.range(10)
   restored = dataset_ops.from_variant(
       dataset_ops.to_variant(original), dataset_ops.get_structure(original))
   self.assertDatasetProduces(restored, range(10))
   # Cardinality is expected to be preserved by the round trip.
   self.assertEqual(self.evaluate(cardinality.cardinality(restored)), 10)
Example #8
0
    def testImportedFunctionsRegistered(self):
        """Re-imports a GraphDef containing a dataset `reduce` and runs it.

        Builds a graph whose dataset comes from a variant placeholder and is
        summed with `reduce`. That GraphDef is then imported inside
        `wrap_function` and evaluated on the variant of a concrete dataset.
        Skipped when a GPU is available.
        """
        if test_util.is_gpu_available():
            self.skipTest('not a GPU test')
        with ops.Graph().as_default() as source_graph:
            variant_in = array_ops.placeholder(
                dtypes.variant, shape=[], name='foo')
            element_spec = tensor_spec.TensorSpec([], dtypes.int32)
            placeholder_ds = dataset_ops.from_variant(
                variant_in, structure=element_spec)
            initial = array_ops.zeros([], dtype=dtypes.int32)
            total = placeholder_ds.reduce(initial,
                                          lambda acc, elem: acc + elem)

        serialized = source_graph.as_graph_def()

        def import_and_reduce(variant):
            # Feed `variant` in place of the placeholder and fetch the sum.
            outputs = graph_def_importer.import_graph_def(
                serialized,
                input_map={variant_in.name: variant},
                return_elements=[total.name])
            return outputs[0]

        variant_spec = tensor_spec.TensorSpec((), dtypes.variant)
        wrapped = wrap_function.wrap_function(import_and_reduce,
                                              [variant_spec])
        concrete_ds = dataset_ops.Dataset.from_tensor_slices([10, 20])
        self.evaluate(wrapped(dataset_ops.to_variant(concrete_ds)))