Code example #1
    def testStringSplitWithDelimiter(self):
        strings = ["hello|world", "hello world"]

        with self.cached_session() as sess:
            self.assertRaises(ValueError,
                              string_ops.string_split,
                              strings,
                              delimiter=["|", ""])

            self.assertRaises(ValueError,
                              string_ops.string_split,
                              strings,
                              delimiter=["a"])

            tokens = string_ops.string_split(strings, delimiter="|")
            indices, values, shape = sess.run(tokens)
            self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0]])
            self.assertAllEqual(values, [b"hello", b"world", b"hello world"])
            self.assertAllEqual(shape, [2, 2])

            tokens = string_ops.string_split(strings, delimiter="| ")
            indices, values, shape = sess.run(tokens)
            self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0], [1, 1]])
            self.assertAllEqual(values,
                                [b"hello", b"world", b"hello", b"world"])
            self.assertAllEqual(shape, [2, 2])
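
Note on the last assertion: the v1 op treats a multi-byte `delimiter` as a *set* of single-byte delimiters, so `"| "` splits on either character. A minimal sketch (assuming an eager TF 2.x install, which is not part of the original test) contrasts this with TF2's whole-string `sep`:

```python
import tensorflow as tf  # assumes eager TF 2.x

# v1 semantics: every byte of the delimiter is an independent split point,
# so "| " splits on both "|" and " ".
sp = tf.compat.v1.string_split(["hello|world", "hello world"], delimiter="| ")
print(sp.values.numpy())  # [b'hello' b'world' b'hello' b'world']

# TF2 semantics: tf.strings.split matches `sep` as a whole string, and the
# two-byte sequence "| " never occurs in either input here.
rt = tf.strings.split(["hello|world", "hello world"], sep="| ")
print(rt)  # <tf.RaggedTensor [[b'hello|world'], [b'hello world']]>
```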
Code example #2
  def testCaptureHashTable(self):
    # NOTE(mrry): We must use the V2 variants of `HashTable`
    # etc. because these produce a `tf.resource`-typed output that is
    # compatible with the in-graph function implementation.
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup_ops.HashTable(
        lookup_ops.KeyValueTensorInitializer(keys, values), default_val)

    input_sentences = dataset_ops.Dataset.from_tensor_slices(
        ["brain brain tank salad surgery", "surgery brain"])

    iterator = (input_sentences
                .map(lambda x: string_ops.string_split([x]).values)
                .map(table.lookup)
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(table.initializer)
      sess.run(init_op)
      sess.run(get_next)
      sess.run(get_next)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
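
For comparison, a minimal eager TF 2.x sketch of the same pipeline (the TF2 API names here are an assumption, not part of the original test):

```python
import tensorflow as tf  # assumes eager TF 2.x

keys = tf.constant(["brain", "salad", "surgery"])
values = tf.constant([0, 1, 2], dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=-1)

sentences = tf.data.Dataset.from_tensor_slices(
    ["brain brain tank salad surgery", "surgery brain"])

# Splitting a scalar string yields a 1-D token tensor; the table then maps
# each token, with the out-of-vocabulary "tank" becoming the default -1.
ids = sentences.map(tf.strings.split).map(table.lookup)

for batch in ids:
    print(batch.numpy())  # [ 0  0 -1  1  2], then [2 0]
```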
Code example #3
def string_split(source,
                 sep=None,
                 skip_empty=True,
                 delimiter=None,
                 result_type="SparseTensor"):  # pylint: disable=invalid-name
    """Split elements of `source` based on `delimiter`.

  Let N be the size of `source` (typically N will be the batch size). Split each
  element of `source` based on `delimiter` and return a `SparseTensor`
  or `RaggedTensor` containing the split tokens. Empty tokens are ignored.

  If `sep` is an empty string, each element of the `source` is split
  into individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
  treated as a set of delimiters with each considered a potential split point.

  Examples:

  ```python
  >>> tf.strings.split(['hello world', 'a b c'])
  tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
                  values=['hello', 'world', 'a', 'b', 'c'],
                  dense_shape=[2, 3])

  >>> tf.strings.split(['hello world', 'a b c'], result_type="RaggedTensor")
  <tf.RaggedTensor [['hello', 'world'], ['a', 'b', 'c']]>
  ```

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character, the string should
      be length 0 or 1. Default is ' '.
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.
    delimiter: deprecated alias for `sep`.
    result_type: The tensor type for the result: one of `"RaggedTensor"` or
      `"SparseTensor"`.

  Raises:
    ValueError: If delimiter is not a string.

  Returns:
    A `SparseTensor` or `RaggedTensor` of rank `2`, the strings split according
    to the delimiter.  The first column of the indices corresponds to the row
    in `source` and the second column corresponds to the index of the split
    component in this row.
  """
    sparse_result = string_ops.string_split(source,
                                            sep=sep,
                                            skip_empty=skip_empty,
                                            delimiter=delimiter)
    if result_type == "SparseTensor":
        return sparse_result
    elif result_type == "RaggedTensor":
        return ragged_tensor.RaggedTensor.from_value_rowids(
            values=sparse_result.values,
            value_rowids=sparse_result.indices[:, 0],
            nrows=sparse_result.dense_shape[0])
    else:
        raise ValueError(
            "result_type must be 'RaggedTensor' or 'SparseTensor'.")
Code example #4
    def testCaptureHashTableInSharedIterator(self):
        worker, _ = test_util.create_local_cluster(1, 1)

        # NOTE(mrry): We must use the V2 variants of `HashTable`
        # etc. because these produce a `tf.resource`-typed output that is
        # compatible with the in-graph function implementation.
        default_val = -1
        keys = constant_op.constant(["brain", "salad", "surgery"])
        values = constant_op.constant([0, 1, 2], dtypes.int64)
        table = lookup_ops.StaticHashTableV1(
            lookup_ops.KeyValueTensorInitializer(keys, values), default_val)

        input_sentences = dataset_ops.Dataset.from_tensor_slices(
            ["brain brain tank salad surgery", "surgery brain"])

        dataset = input_sentences.map(
            lambda x: string_ops.string_split([x]).values).map(table.lookup)
        iterator = dataset_ops.make_initializable_iterator(
            dataset, shared_name="shared_iterator")
        init_op = iterator.initializer
        get_next = iterator.get_next()

        with session.Session(worker[0].target) as sess:
            sess.run(table.initializer)
            sess.run(init_op)
            self.assertAllEqual([0, 0, -1, 1, 2], sess.run(get_next))

        with session.Session(worker[0].target) as sess:
            self.assertAllEqual([2, 0], sess.run(get_next))
            with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next)
Code example #5
  def testCaptureHashTableInSharedIterator(self):
    worker, _ = test_util.create_local_cluster(1, 1)

    # NOTE(mrry): We must use the V2 variants of `HashTable`
    # etc. because these produce a `tf.resource`-typed output that is
    # compatible with the in-graph function implementation.
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery"])
    values = constant_op.constant([0, 1, 2], dtypes.int64)
    table = lookup_ops.HashTable(
        lookup_ops.KeyValueTensorInitializer(keys, values),
        default_val,
        shared_name="shared_table")

    input_sentences = dataset_ops.Dataset.from_tensor_slices(
        ["brain brain tank salad surgery", "surgery brain"])

    iterator = (
        input_sentences.map(lambda x: string_ops.string_split([x]).values).map(
            table.lookup)
        .make_initializable_iterator(shared_name="shared_iterator"))
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with session.Session(worker[0].target) as sess:
      self.evaluate(table.initializer)
      self.evaluate(init_op)
      self.assertAllEqual([0, 0, -1, 1, 2], self.evaluate(get_next))

    with session.Session(worker[0].target) as sess:
      self.assertAllEqual([2, 0], self.evaluate(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
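
The point of `shared_name` in both variants above is that the iterator resource lives on the worker, keyed by name, so a second session against the same target resumes it rather than restarting. A hedged, self-contained graph-mode sketch, with an in-process server standing in for the test cluster:

```python
import tensorflow as tf  # assumes TF 2.x running in v1 graph-mode compatibility

tf1 = tf.compat.v1
tf1.disable_eager_execution()

server = tf1.train.Server.create_local_server()

dataset = tf1.data.Dataset.from_tensor_slices([10, 20, 30])
iterator = tf1.data.make_initializable_iterator(
    dataset, shared_name="shared_iterator")
get_next = iterator.get_next()

# The named iterator resource persists on the server between sessions.
with tf1.Session(server.target) as sess:
    sess.run(iterator.initializer)
    print(sess.run(get_next))  # 10

with tf1.Session(server.target) as sess:
    print(sess.run(get_next))  # 20, resumed rather than restarted
```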
Code example #6
  def testStringSplit(self):
    strings = ["pigs on the wing", "animals"]

    with self.cached_session() as sess:
      tokens = string_ops.string_split(strings)
      indices, values, shape = self.evaluate(tokens)
      self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [1, 0]])
      self.assertAllEqual(values, [b"pigs", b"on", b"the", b"wing", b"animals"])
      self.assertAllEqual(shape, [2, 4])
Code example #7
  def testStringSplitWithNoSkipEmpty(self):
    strings = ["#a", "b#", "#c#"]

    with self.cached_session() as sess:
      tokens = string_ops.string_split(strings, "#", skip_empty=False)
      indices, values, shape = self.evaluate(tokens)
      self.assertAllEqual(indices, [[0, 0], [0, 1],
                                    [1, 0], [1, 1],
                                    [2, 0], [2, 1], [2, 2]])
      self.assertAllEqual(values, [b"", b"a", b"b", b"", b"", b"c", b""])
      self.assertAllEqual(shape, [3, 3])

    with self.cached_session() as sess:
      tokens = string_ops.string_split(strings, "#")
      indices, values, shape = self.evaluate(tokens)
      self.assertAllEqual(values, [b"a", b"b", b"c"])
      self.assertAllEqual(indices, [[0, 0], [1, 0], [2, 0]])
      self.assertAllEqual(shape, [3, 1])
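
TF2's `tf.strings.split` always keeps empty tokens when given an explicit separator; there is no `skip_empty` switch. A hedged sketch of both behaviors (assuming eager TF 2.x):

```python
import tensorflow as tf  # assumes eager TF 2.x

# With an explicit sep, TF2 keeps empty tokens (the skip_empty=False case):
rt = tf.strings.split(["#a", "b#", "#c#"], sep="#")
print(rt)  # <tf.RaggedTensor [[b'', b'a'], [b'b', b''], [b'', b'c', b'']]>

# A ragged boolean mask recovers the skip_empty=True result:
print(tf.ragged.boolean_mask(rt, tf.not_equal(rt, "")))
# <tf.RaggedTensor [[b'a'], [b'b'], [b'c']]>
```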
Code example #8
def string_split(source, sep=None, skip_empty=True, delimiter=None,
                 result_type="SparseTensor", name=None):  # pylint: disable=invalid-name
  """Split elements of `source` based on `delimiter`.

  Let N be the size of `source` (typically N will be the batch size). Split each
  element of `source` based on `delimiter` and return a `SparseTensor`
  or `RaggedTensor` containing the split tokens. Empty tokens are ignored.

  If `sep` is an empty string, each element of the `source` is split
  into individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
  treated as a set of delimiters with each considered a potential split point.

  Examples:

  ```python
  >>> tf.strings.split(['hello world', 'a b c'])
  tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
                  values=['hello', 'world', 'a', 'b', 'c'],
                  dense_shape=[2, 3])

  >>> tf.strings.split(['hello world', 'a b c'], result_type="RaggedTensor")
  <tf.RaggedTensor [['hello', 'world'], ['a', 'b', 'c']]>
  ```

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character, the string should
      be length 0 or 1. Default is ' '.
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.
    delimiter: deprecated alias for `sep`.
    result_type: The tensor type for the result: one of `"RaggedTensor"` or
      `"SparseTensor"`.
    name: A name for the operation (optional).

  Raises:
    ValueError: If delimiter is not a string.

  Returns:
    A `SparseTensor` or `RaggedTensor` of rank `2`, the strings split according
    to the delimiter.  The first column of the indices corresponds to the row
    in `source` and the second column corresponds to the index of the split
    component in this row.
  """
  with ops.name_scope(name, "StringSplit", [source]):
    sparse_result = string_ops.string_split(
        source, sep=sep, skip_empty=skip_empty, delimiter=delimiter)
    if result_type == "SparseTensor":
      return sparse_result
    elif result_type == "RaggedTensor":
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=sparse_result.values,
          value_rowids=sparse_result.indices[:, 0],
          nrows=sparse_result.dense_shape[0],
          validate=False)
    else:
      raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.")
Code example #9
  def testStringSplitEmptyToken(self):
    strings = [" hello ", "", "world "]

    with self.test_session() as sess:
      tokens = string_ops.string_split(strings)
      indices, values, shape = sess.run(tokens)
      self.assertAllEqual(indices, [[0, 0], [2, 0]])
      self.assertAllEqual(values, [b"hello", b"world"])
      self.assertAllEqual(shape, [3, 1])
Code example #10
def string_split(source, sep=None, skip_empty=True, delimiter=None,
                 result_type="SparseTensor", name=None):  # pylint: disable=invalid-name
  """Split elements of `source` based on `delimiter`.

  Let N be the size of `source` (typically N will be the batch size). Split each
  element of `source` based on `delimiter` and return a `SparseTensor`
  or `RaggedTensor` containing the split tokens. Empty tokens are ignored.

  If `sep` is an empty string, each element of the `source` is split
  into individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
  treated as a set of delimiters with each considered a potential split point.

  Examples:

  >>> print(tf.compat.v1.string_split(['hello world', 'a b c']))
  SparseTensor(indices=tf.Tensor( [[0 0] [0 1] [1 0] [1 1] [1 2]], ...),
               values=tf.Tensor([b'hello' b'world' b'a' b'b' b'c'], ...),
               dense_shape=tf.Tensor([2 3], shape=(2,), dtype=int64))

  >>> print(tf.compat.v1.string_split(['hello world', 'a b c'],
  ...     result_type="RaggedTensor"))
  <tf.RaggedTensor [[b'hello', b'world'], [b'a', b'b', b'c']]>

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character, the string should
      be length 0 or 1. Default is ' '.
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.
    delimiter: deprecated alias for `sep`.
    result_type: The tensor type for the result: one of `"RaggedTensor"` or
      `"SparseTensor"`.
    name: A name for the operation (optional).

  Raises:
    ValueError: If delimiter is not a string.

  Returns:
    A `SparseTensor` or `RaggedTensor` of rank `2`, the strings split according
    to the delimiter.  The first column of the indices corresponds to the row
    in `source` and the second column corresponds to the index of the split
    component in this row.
  """
  with ops.name_scope(name, "StringSplit", [source]):
    sparse_result = string_ops.string_split(
        source, sep=sep, skip_empty=skip_empty, delimiter=delimiter)
    if result_type == "SparseTensor":
      return sparse_result
    elif result_type == "RaggedTensor":
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=sparse_result.values,
          value_rowids=sparse_result.indices[:, 0],
          nrows=sparse_result.dense_shape[0],
          validate=False)
    else:
      raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.")
Code example #11
  def testStringSplitEmptyToken(self):
    strings = ["", " a", "b ", " c", " ", " d ", "  e", "f  ", "  g  ", "  "]

    with self.test_session() as sess:
      tokens = string_ops.string_split(strings)
      indices, values, shape = sess.run(tokens)
      self.assertAllEqual(
          indices,
          [[1, 0], [2, 0], [3, 0], [5, 0], [6, 0], [7, 0], [8, 0]])
      self.assertAllEqual(values, [b"a", b"b", b"c", b"d", b"e", b"f", b"g"])
      self.assertAllEqual(shape, [10, 1])
Code example #12
  def testStringSplitOnSetEmptyToken(self):
    strings = ["", " a", "b ", " c", " ", " d ", ". e", "f .", " .g. ", " ."]

    with self.cached_session() as sess:
      tokens = string_ops.string_split(strings, delimiter=" .")
      indices, values, shape = self.evaluate(tokens)
      self.assertAllEqual(
          indices,
          [[1, 0], [2, 0], [3, 0], [5, 0], [6, 0], [7, 0], [8, 0]])
      self.assertAllEqual(values, [b"a", b"b", b"c", b"d", b"e", b"f", b"g"])
      self.assertAllEqual(shape, [10, 1])
Code example #13
  def testStringSplitWithDelimiter(self):
    strings = ["hello|world", "hello world"]

    with self.cached_session() as sess:
      self.assertRaises(
          ValueError, string_ops.string_split, strings, delimiter=["|", ""])

      self.assertRaises(
          ValueError, string_ops.string_split, strings, delimiter=["a"])

      tokens = string_ops.string_split(strings, delimiter="|")
      indices, values, shape = self.evaluate(tokens)
      self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0]])
      self.assertAllEqual(values, [b"hello", b"world", b"hello world"])
      self.assertAllEqual(shape, [2, 2])

      tokens = string_ops.string_split(strings, delimiter="| ")
      indices, values, shape = self.evaluate(tokens)
      self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0], [1, 1]])
      self.assertAllEqual(values, [b"hello", b"world", b"hello", b"world"])
      self.assertAllEqual(shape, [2, 2])
Code example #14
    def testStringSplitEmptyToken(self):
        strings = [
            "", " a", "b ", " c", " ", " d ", "  e", "f  ", "  g  ", "  "
        ]

        with self.cached_session() as sess:
            tokens = string_ops.string_split(strings)
            indices, values, shape = sess.run(tokens)
            self.assertAllEqual(
                indices,
                [[1, 0], [2, 0], [3, 0], [5, 0], [6, 0], [7, 0], [8, 0]])
            self.assertAllEqual(values,
                                [b"a", b"b", b"c", b"d", b"e", b"f", b"g"])
            self.assertAllEqual(shape, [10, 1])
Code example #15
  def testStringSplitEmptyDelimiter(self):
    strings = ["hello", "hola", b"\xF0\x9F\x98\x8E"]  # Last string is U+1F60E

    with self.cached_session() as sess:
      tokens = string_ops.string_split(strings, delimiter="")
      indices, values, shape = self.evaluate(tokens)
      self.assertAllEqual(indices, [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4],
                                    [1, 0], [1, 1], [1, 2], [1, 3], [2, 0],
                                    [2, 1], [2, 2], [2, 3]])
      expected = np.array(
          [
              "h", "e", "l", "l", "o", "h", "o", "l", "a", b"\xf0", b"\x9f",
              b"\x98", b"\x8e"
          ],
          dtype="|S1")
      self.assertAllEqual(values.tolist(), expected)
      self.assertAllEqual(shape, [3, 5])
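
In TF2 the empty-delimiter byte split has a dedicated op. A hedged sketch (assuming eager TF 2.x):

```python
import tensorflow as tf  # assumes eager TF 2.x

# tf.strings.bytes_split is the analogue of delimiter="": one byte per token.
print(tf.strings.bytes_split("hola"))
# tf.Tensor([b'h' b'o' b'l' b'a'], shape=(4,), dtype=string)

# The 4-byte emoji above likewise splits into four single-byte tokens:
print(tf.strings.bytes_split(b"\xF0\x9F\x98\x8E").shape)  # (4,)

# To split on characters rather than bytes, decode UTF-8 first:
print(tf.strings.unicode_split(b"\xF0\x9F\x98\x8E", "UTF-8"))
# tf.Tensor([b'\xf0\x9f\x98\x8e'], shape=(1,), dtype=string)
```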
Code example #16
  def testStringSplitWithDelimiterTensor(self):
    strings = ["hello|world", "hello world"]

    with self.cached_session() as sess:
      delimiter = array_ops.placeholder(dtypes.string)

      tokens = string_ops.string_split(strings, delimiter=delimiter)

      with self.assertRaises(errors_impl.InvalidArgumentError):
        sess.run(tokens, feed_dict={delimiter: ["a", "b"]})
      with self.assertRaises(errors_impl.InvalidArgumentError):
        sess.run(tokens, feed_dict={delimiter: ["a"]})
      indices, values, shape = sess.run(tokens, feed_dict={delimiter: "|"})

      self.assertAllEqual(indices, [[0, 0], [0, 1], [1, 0]])
      self.assertAllEqual(values, [b"hello", b"world", b"hello world"])
      self.assertAllEqual(shape, [2, 2])
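
The split between `ValueError` and `errors_impl.InvalidArgumentError` here is about *when* the rank check can run: a Python-list delimiter is rejected during graph construction by shape inference, while a fed placeholder only fails inside the kernel at `sess.run` time. In eager mode the kernel runs immediately, so (as a hedged sketch, assuming eager TF 2.x) the same misuse should surface as `InvalidArgumentError`:

```python
import tensorflow as tf  # assumes eager TF 2.x

try:
    # A rank-1 delimiter tensor reaches the kernel, which requires a scalar.
    tf.compat.v1.string_split(["a|b"], delimiter=tf.constant(["|"]))
except tf.errors.InvalidArgumentError as e:
    print(type(e).__name__)  # InvalidArgumentError
```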
Code example #17
def feature_engineering_fn(features, labels):
  # GitHub #12205: raise a TypeError if called twice.
  _ = string_ops.string_split(features["x"])
  features["x"] = constant_op.constant([9.])
  labels["y"] = constant_op.constant([99.])
  return features, labels
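
Context for the comment: the hook replaces the string feature with a float constant, so if the framework invoked it a second time, `string_split` would reject the now-float `features["x"]` (GitHub issue #12205). A hedged standalone sketch of a single invocation (assuming eager TF 2.x; the dict contents are illustrative):

```python
import tensorflow as tf  # assumes eager TF 2.x; inputs are illustrative

def feature_engineering_fn(features, labels):
    # After this call, features["x"] is float, so a second invocation would
    # fail inside string_split -- which is exactly what the test checks for.
    _ = tf.compat.v1.string_split(features["x"])
    features["x"] = tf.constant([9.])
    labels["y"] = tf.constant([99.])
    return features, labels

features = {"x": tf.constant(["some raw text"])}
labels = {"y": tf.constant([0.])}
features, labels = feature_engineering_fn(features, labels)
print(features["x"].numpy(), labels["y"].numpy())  # [9.] [99.]
```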