def testBadInputShapes(
      self, token_word, token_starts, token_ends, token_properties,
      is_ragged=True):
    """Checks that sentence_fragments rejects inputs with incompatible shapes.

    NOTE(review): another method named `testBadInputShapes` is defined right
    after this one; if both live in the same class, the later definition
    replaces this one. Confirm which version is intended to survive.

    Args:
      token_word: Token strings, passed to the op unchanged.
      token_starts: Token start positions; converted to an int64 ragged tensor
        when `is_ragged` is true, otherwise to an int64 dense tensor.
      token_ends: Token end positions; converted the same way as
        `token_starts`.
      token_properties: Per-token property values; always converted to an
        int64 ragged tensor.
      is_ragged: Chooses ragged vs. dense conversion for the start/end inputs.
    """
    to_tensor = tf.ragged.constant if is_ragged else tf.constant
    starts = to_tensor(token_starts, dtype=tf.int64)
    ends = to_tensor(token_ends, dtype=tf.int64)
    # Properties stay ragged regardless of `is_ragged`.
    properties = tf.ragged.constant(token_properties, dtype=tf.int64)

    # Both the op call and its evaluation sit inside the context manager, so
    # the expected error is accepted from either step.
    with self.assertRaises(tf.errors.InvalidArgumentError):
      self.evaluate(
          sentence_breaking_ops.sentence_fragments(
              token_word, starts, ends, properties))
  def testBadInputShapes(self,
                         test_description,
                         token_word,
                         token_starts,
                         token_ends,
                         token_properties,
                         is_ragged=True):
    """Checks that sentence_fragments rejects inputs with incompatible shapes.

    Args:
      test_description: Case label; not referenced in the body (presumably
        consumed by a parameterization decorator for naming -- confirm).
      token_word: Token strings, passed to the op unchanged.
      token_starts: Token start positions; converted to an int64 ragged tensor
        when `is_ragged` is true, otherwise to an int64 dense tensor.
      token_ends: Token end positions; converted the same way as
        `token_starts`.
      token_properties: Per-token property values; always converted to an
        int64 ragged tensor.
      is_ragged: Chooses ragged vs. dense conversion for the start/end inputs.
    """
    constant = ragged_factory_ops.constant if is_ragged else constant_op.constant
    token_starts = constant(token_starts, dtype=dtypes.int64)
    token_ends = constant(token_ends, dtype=dtypes.int64)
    # Properties stay ragged regardless of `is_ragged`.
    token_properties = ragged_factory_ops.constant(
        token_properties, dtype=dtypes.int64)

    # The op call and its evaluation are both inside the context manager, so
    # the expected error is accepted from either step.
    with self.assertRaises(errors.InvalidArgumentError):
      result = sentence_breaking_ops.sentence_fragments(
          token_word, token_starts, token_ends, token_properties)
      _ = self.evaluate(result)

  # This method was previously indented one level deeper, which made it a local
  # function nested inside testBadInputShapes and therefore never collected as
  # a test. It now sits at method level alongside its siblings.
  # NOTE(review): it takes per-case arguments, so it presumably needs a
  # parameterization decorator to supply them -- confirm none was lost.
  def testDenseInputs(self, token_word, token_properties,
                      expected_fragment_start, expected_fragment_end,
                      expected_fragment_properties, expected_terminal_punc):
    """Checks sentence_fragments with dense word and property tensors.

    Args:
      token_word: Token strings; first used to derive start/end offsets via
        `self.getTokenOffsets`, then converted to a dense string tensor.
      token_properties: Per-token property values, converted to a dense int64
        tensor.
      expected_fragment_start: Expected first output of sentence_fragments.
      expected_fragment_end: Expected second output.
      expected_fragment_properties: Expected third output.
      expected_terminal_punc: Expected fourth output.
    """
    # Offsets are computed from the raw token_word argument, before it is
    # rebound to a tensor below.
    token_starts, token_ends = self.getTokenOffsets(token_word)
    token_properties = constant_op.constant(token_properties,
                                            dtype=dtypes.int64)
    token_word = constant_op.constant(token_word, dtype=dtypes.string)

    fragments = sentence_breaking_ops.sentence_fragments(
        token_word, token_starts, token_ends, token_properties)

    fragment_starts, fragment_ends, fragment_properties, terminal_punc = (
        fragments)
    self.assertAllEqual(expected_fragment_start, fragment_starts)
    self.assertAllEqual(expected_fragment_end, fragment_ends)
    self.assertAllEqual(expected_fragment_properties, fragment_properties)
    self.assertAllEqual(expected_terminal_punc, terminal_punc)
  def testSentenceFragmentOp(
      self, text, token_starts, token_ends, token_properties,
      expected_fragment_start, expected_fragment_end,
      expected_fragment_properties, expected_terminal_punc):
    """Checks sentence_fragments on ragged token inputs built from `text`.

    NOTE(review): another method named `testSentenceFragmentOp` is defined
    right after this one; if both live in the same class, the later
    definition replaces this one. Confirm which version is intended.

    Args:
      text: Source text, converted with tf.constant.
      token_starts: Token start positions (presumably offsets into `text`),
        converted to an int64 ragged tensor.
      token_ends: Token end positions, converted to an int64 ragged tensor.
      token_properties: Per-token property values, converted to an int64
        ragged tensor.
      expected_fragment_start: Expected first output of sentence_fragments.
      expected_fragment_end: Expected second output.
      expected_fragment_properties: Expected third output.
      expected_terminal_punc: Expected fourth output.
    """
    text = tf.constant(text)
    token_starts = tf.ragged.constant(token_starts, dtype=tf.int64)
    token_ends = tf.ragged.constant(token_ends, dtype=tf.int64)
    token_properties = tf.ragged.constant(token_properties, dtype=tf.int64)
    token_word = self.getTokenWord(text, token_starts, token_ends)

    fragments = sentence_breaking_ops.sentence_fragments(
        token_word, token_starts, token_ends, token_properties)

    fragment_starts, fragment_ends, fragment_properties, terminal_punc = (
        fragments)
    # assertAllEqual is used for these ragged outputs, matching the other
    # testSentenceFragmentOp in this file (instead of assertRaggedEqual).
    self.assertAllEqual(expected_fragment_start, fragment_starts)
    self.assertAllEqual(expected_fragment_end, fragment_ends)
    self.assertAllEqual(expected_fragment_properties, fragment_properties)
    self.assertAllEqual(expected_terminal_punc, terminal_punc)
  def testSentenceFragmentOp(self, test_description, text, token_starts,
                             token_ends, token_properties,
                             expected_fragment_start, expected_fragment_end,
                             expected_fragment_properties,
                             expected_terminal_punc):
    """Checks sentence_fragments on ragged token inputs built from `text`.

    Args:
      test_description: Case label; not referenced in the body.
      text: Source text, converted with constant_op.constant.
      token_starts: Token start positions (presumably offsets into `text`),
        converted to an int64 ragged tensor.
      token_ends: Token end positions, converted to an int64 ragged tensor.
      token_properties: Per-token property values, converted to an int64
        ragged tensor.
      expected_fragment_start: Expected first output of sentence_fragments.
      expected_fragment_end: Expected second output.
      expected_fragment_properties: Expected third output.
      expected_terminal_punc: Expected fourth output.
    """
    text_tensor = constant_op.constant(text)
    # Convert starts, ends and properties (in that order) to int64 ragged.
    starts, ends, properties = [
        ragged_factory_ops.constant(values, dtype=dtypes.int64)
        for values in (token_starts, token_ends, token_properties)
    ]
    words = self.getTokenWord(text_tensor, starts, ends)

    # Unpacking enforces that the op yields exactly four outputs.
    (fragment_starts, fragment_ends, fragment_properties,
     terminal_punc) = sentence_breaking_ops.sentence_fragments(
         words, starts, ends, properties)

    checks = (
        (expected_fragment_start, fragment_starts),
        (expected_fragment_end, fragment_ends),
        (expected_fragment_properties, fragment_properties),
        (expected_terminal_punc, terminal_punc),
    )
    for expected, actual in checks:
      self.assertAllEqual(expected, actual)