def testBadInputShapes(self, token_word, token_starts, token_ends,
                       token_properties, is_ragged=True):
  """Verifies sentence_fragments raises on malformed input shapes.

  Converts the offset/property inputs to tensors (ragged or dense per
  `is_ragged`) and asserts that building and evaluating the op fails
  with InvalidArgumentError.
  """
  # Pick the tensor factory matching the requested input kind.
  make_tensor = tf.ragged.constant if is_ragged else tf.constant
  starts = make_tensor(token_starts, dtype=tf.int64)
  ends = make_tensor(token_ends, dtype=tf.int64)
  # Properties are always ragged, regardless of `is_ragged`.
  properties = tf.ragged.constant(token_properties, dtype=tf.int64)
  with self.assertRaises(tf.errors.InvalidArgumentError):
    fragments = sentence_breaking_ops.sentence_fragments(
        token_word, starts, ends, properties)
    _ = self.evaluate(fragments)
def testBadInputShapes(self, test_description, token_word, token_starts,
                       token_ends, token_properties, is_ragged=True):
  """Checks sentence_fragments rejects inputs with inconsistent shapes.

  `test_description` only labels the parameterized case and is unused in
  the body.
  """
  # Select dense vs. ragged construction for the offset tensors.
  factory = (ragged_factory_ops.constant
             if is_ragged else constant_op.constant)
  starts = factory(token_starts, dtype=dtypes.int64)
  ends = factory(token_ends, dtype=dtypes.int64)
  # Token properties are always constructed as a ragged tensor.
  properties = ragged_factory_ops.constant(
      token_properties, dtype=dtypes.int64)
  with self.assertRaises(errors.InvalidArgumentError):
    fragments = sentence_breaking_ops.sentence_fragments(
        token_word, starts, ends, properties)
    _ = self.evaluate(fragments)
def testDenseInputs(self, token_word, token_properties,
                    expected_fragment_start, expected_fragment_end,
                    expected_fragment_properties, expected_terminal_punc):
  """Exercises sentence_fragments with dense (non-ragged) token inputs."""
  # Derive byte offsets from the raw word list BEFORE converting the
  # words to a tensor — getTokenOffsets expects the Python-level input.
  starts, ends = self.getTokenOffsets(token_word)
  properties = constant_op.constant(token_properties, dtype=dtypes.int64)
  words = constant_op.constant(token_word, dtype=dtypes.string)
  (fragment_starts, fragment_ends, fragment_properties,
   terminal_punc) = sentence_breaking_ops.sentence_fragments(
       words, starts, ends, properties)
  self.assertAllEqual(expected_fragment_start, fragment_starts)
  self.assertAllEqual(expected_fragment_end, fragment_ends)
  self.assertAllEqual(expected_fragment_properties, fragment_properties)
  self.assertAllEqual(expected_terminal_punc, terminal_punc)
def testSentenceFragmentOp(self, text, token_starts, token_ends,
                           token_properties, expected_fragment_start,
                           expected_fragment_end,
                           expected_fragment_properties,
                           expected_terminal_punc):
  """Tests sentence_fragments end to end on ragged tokenized text."""
  text_tensor = tf.constant(text)
  starts = tf.ragged.constant(token_starts, dtype=tf.int64)
  ends = tf.ragged.constant(token_ends, dtype=tf.int64)
  properties = tf.ragged.constant(token_properties, dtype=tf.int64)
  # Slice the word strings out of the text via the token offsets.
  words = self.getTokenWord(text_tensor, starts, ends)
  (fragment_starts, fragment_ends, fragment_properties,
   terminal_punc) = sentence_breaking_ops.sentence_fragments(
       words, starts, ends, properties)
  self.assertRaggedEqual(expected_fragment_start, fragment_starts)
  self.assertRaggedEqual(expected_fragment_end, fragment_ends)
  self.assertRaggedEqual(expected_fragment_properties, fragment_properties)
  self.assertRaggedEqual(expected_terminal_punc, terminal_punc)
def testSentenceFragmentOp(self, test_description, text, token_starts,
                           token_ends, token_properties,
                           expected_fragment_start, expected_fragment_end,
                           expected_fragment_properties,
                           expected_terminal_punc):
  """Tests sentence_fragments end to end on ragged tokenized text.

  `test_description` only labels the parameterized case and is unused in
  the body.
  """
  text_tensor = constant_op.constant(text)
  starts = ragged_factory_ops.constant(token_starts, dtype=dtypes.int64)
  ends = ragged_factory_ops.constant(token_ends, dtype=dtypes.int64)
  properties = ragged_factory_ops.constant(
      token_properties, dtype=dtypes.int64)
  # Slice the word strings out of the text via the token offsets.
  words = self.getTokenWord(text_tensor, starts, ends)
  (fragment_starts, fragment_ends, fragment_properties,
   terminal_punc) = sentence_breaking_ops.sentence_fragments(
       words, starts, ends, properties)
  self.assertAllEqual(expected_fragment_start, fragment_starts)
  self.assertAllEqual(expected_fragment_end, fragment_ends)
  self.assertAllEqual(expected_fragment_properties, fragment_properties)
  self.assertAllEqual(expected_terminal_punc, terminal_punc)