Code example #1
0
File: test_base.py  Project: MFChunga/poo
 def graphRoundTrip(self, dataset, allow_stateful=False):
     """Serializes `dataset` to a graph and rebuilds a dataset from it.

     Args:
       dataset: The dataset to round-trip. Its `_variant_tensor` is
         serialized, and its `element_spec` is reused for the rebuilt dataset.
       allow_stateful: Forwarded unchanged to
         `gen_dataset_ops.dataset_to_graph`.

     Returns:
       The dataset produced by `dataset_ops.from_variant` from the
       deserialized graph.
     """
     # Bind the protected attribute on its own line so the lint suppression
     # sits on the same line as the access it is meant to silence.
     source_variant = dataset._variant_tensor  # pylint: disable=protected-access
     serialized_graph = gen_dataset_ops.dataset_to_graph(
         source_variant, allow_stateful=allow_stateful)
     rebuilt_variant = gen_experimental_dataset_ops.dataset_from_graph(
         serialized_graph)
     return dataset_ops.from_variant(rebuilt_variant, dataset.element_spec)
Code example #2
0
  def testSerialization(self):
    """Checks that a tokenizer-mapped dataset survives graph serialization.

    Maps `SentencepieceTokenizer.tokenize` over a small string dataset,
    serializes the dataset graph, resets the eager context, rebuilds the
    dataset from the serialized graph, and compares every restored element
    with the tokenizer's direct output on the same strings.
    """
    with context.eager_mode():
      sentencepiece_model_file = (
          'tensorflow_text/python/ops/test_data/'
          'test_oss_model.model')
      # Close the model file deterministically instead of leaking the handle.
      with gfile.GFile(sentencepiece_model_file, 'rb') as model_file:
        model = model_file.read()
      sp = SentencepieceTokenizer(model)
      strings = ['hello', 'world']
      dataset = dataset_ops.Dataset.from_tensor_slices(strings)
      # Ensure we can map the tokenizer across the dataset.
      dataset = dataset.map(sp.tokenize)
      graph = dataset._as_serialized_graph()  # pylint: disable=protected-access
      element_spec = dataset.element_spec
      dataset_graph_string = graph.numpy()
      expected = sp.tokenize(strings)

    # Reset the eager context to make sure that the serialized dataset graph
    # is self-contained.
    context._reset_context()  # pylint: disable=protected-access

    with context.eager_mode():
      restored = dataset_ops.from_variant(
          gen_experimental_dataset_ops.dataset_from_graph(dataset_graph_string),
          element_spec)
      results = list(restored)
      # Without this count check, an empty restored dataset would pass
      # vacuously because the comparison loop below would never run.
      self.assertEqual(len(results), len(strings))
      for i, result in enumerate(results):
        self.assertAllEqual(result, expected[i])
Code example #3
0
 def __init__(self, graph_def, device, element_spec):
     """Creates the dataset by deserializing `graph_def` under `device`.

     Args:
       graph_def: Serialized dataset graph passed to
         `ged_ops.dataset_from_graph`.
       device: Device scope (via `ops.device`) in effect while
         `ged_ops.dataset_from_graph` is called.
       element_spec: Stored as `self._elem_spec`.
     """
     # Assigned before the base constructor runs. NOTE(review): presumably
     # because the base class, or an `element_spec` property defined outside
     # this view, reads it during construction -- confirm before reordering.
     self._elem_spec = element_spec
     device_scope = ops.device(device)
     with device_scope:
         remote_variant = ged_ops.dataset_from_graph(graph_def)
     super(_RemoteDataset, self).__init__(remote_variant)