def graphRoundTrip(self, dataset, allow_stateful=False):
  """Round-trips `dataset` through its serialized-graph representation.

  Serializes the dataset's variant tensor to a `GraphDef` string and then
  reconstructs an equivalent dataset from that serialization.

  Args:
    dataset: the `tf.data.Dataset` to serialize and restore.
    allow_stateful: whether stateful ops are permitted in the serialized
      graph (forwarded to `dataset_to_graph`).

  Returns:
    A dataset reconstructed from the serialized graph, with the same
    `element_spec` as `dataset`.
  """
  # pylint: disable=protected-access
  serialized = gen_dataset_ops.dataset_to_graph(
      dataset._variant_tensor, allow_stateful=allow_stateful)
  restored_variant = gen_experimental_dataset_ops.dataset_from_graph(serialized)
  return dataset_ops.from_variant(restored_variant, dataset.element_spec)
def testSerialization(self):
  """Verifies a serialized dataset graph survives an eager-context reset."""
  with context.eager_mode():
    model_path = ('tensorflow_text/python/ops/test_data/'
                  'test_oss_model.model')
    tokenizer = SentencepieceTokenizer(gfile.GFile(model_path, 'rb').read())
    words = ['hello', 'world']
    # Ensure we can map the tokenizer across the dataset.
    tokenized = dataset_ops.Dataset.from_tensor_slices(words).map(
        tokenizer.tokenize)
    spec = tokenized.element_spec
    serialized_graph = tokenized._as_serialized_graph().numpy()
    expected = tokenizer.tokenize(words)
    # Reset the eager context to make sure that the serialized dataset graph
    # is self-contained.
    context._reset_context()
    with context.eager_mode():
      restored = dataset_ops.from_variant(
          gen_experimental_dataset_ops.dataset_from_graph(serialized_graph),
          spec)
      for i, got in enumerate(restored):
        self.assertAllEqual(got, expected[i])
def __init__(self, graph_def, device, element_spec):
  """Reconstructs a dataset from `graph_def`, placed on `device`.

  Args:
    graph_def: serialized dataset graph from which to rebuild the dataset.
    device: device on which the reconstruction op is placed.
    element_spec: structure spec describing the dataset's elements.
  """
  self._elem_spec = element_spec
  # Only the reconstruction op needs the device scope; the base-class
  # initialization runs outside it.
  with ops.device(device):
    variant = ged_ops.dataset_from_graph(graph_def)
  super(_RemoteDataset, self).__init__(variant)