def testCollapseRepeatedAllLabelsTheSame(self):
    """Rows consisting of a single repeated label collapse to one element."""
    # Each row is the label 1 repeated.  Whatever prefix seq_length selects
    # (4, 5 or 1 entries), the run of identical labels collapses to [1].
    all_ones = [
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
    ]
    out_labels, out_lengths = ctc_ops.collapse_repeated(
        labels=all_ones, seq_length=[4, 5, 1])

    expected_lengths = [1, 1, 1]
    expected_labels = [[1], [1], [1]]
    self.assertAllEqual(out_lengths, expected_lengths)
    self.assertAllEqual(out_labels, expected_labels)
def testCollapseRepeatedExtraPadding(self):
    """Entries beyond each row's seq_length do not affect the result."""
    # Rows are 7 wide, but only the first 4/5/5 entries are in range.  The
    # trailing "1, 2" in the second row lies past its length and must not
    # appear in the output.
    padded_rows = [
        [1, 3, 3, 3, 0, 0, 0],
        [1, 4, 4, 4, 0, 1, 2],
        [4, 2, 2, 9, 4, 0, 0],
    ]
    row_lengths = [4, 5, 5]
    out_labels, out_lengths = ctc_ops.collapse_repeated(
        labels=padded_rows, seq_length=row_lengths)

    self.assertAllEqual(out_lengths, [2, 3, 4])
    # Output width equals the longest collapsed length (4).  In the second
    # row the first 0 is a genuine in-range label; the final 0 is padding.
    self.assertAllEqual(
        out_labels,
        [
            [1, 3, 0, 0],
            [1, 4, 0, 0],
            [4, 2, 9, 4],
        ])
def testCollapseRepeatedFrontRepeats(self):
    """A leading run of repeats collapses correctly under varying lengths."""
    # All rows are identical; only seq_length differs (5, 4, 3).  A length
    # of 3 keeps only the "1, 1, 1" prefix, so that row collapses to [1].
    same_row = [1, 1, 1, 2, 2]
    out_labels, out_lengths = ctc_ops.collapse_repeated(
        labels=[same_row, same_row, same_row], seq_length=[5, 4, 3])

    self.assertAllEqual(out_lengths, [2, 2, 1])
    # The shorter third row is right-padded with 0 to width 2.
    self.assertAllEqual(
        out_labels,
        [
            [1, 2],
            [1, 2],
            [1, 0],
        ])
def testCollapseRepeatedPreservesDtypes(self):
    """int64 inputs produce int64 outputs with the usual collapsed values."""
    # Build the inputs as explicit int64 tensors so the dtype of both
    # outputs can be checked, not just their contents.
    labels_i64 = constant_op.constant(
        [
            [1, 3, 3, 3, 0],
            [1, 4, 4, 4, 0],
            [4, 2, 2, 9, 4],
        ],
        dtype=dtypes.int64)
    lengths_i64 = constant_op.constant([4, 5, 5], dtype=dtypes.int64)

    out_labels, out_lengths = ctc_ops.collapse_repeated(
        labels=labels_i64, seq_length=lengths_i64)

    # Output dtypes must match the int64 inputs.
    self.assertEqual(out_lengths.dtype, dtypes.int64)
    self.assertEqual(out_labels.dtype, dtypes.int64)
    # Values match the int32 case: repeats merged, rows padded with 0.
    self.assertAllEqual(out_lengths, [2, 3, 4])
    self.assertAllEqual(
        out_labels,
        [
            [1, 3, 0, 0],
            [1, 4, 0, 0],
            [4, 2, 9, 4],
        ])