コード例 #1
0
 def testCollapseRepeatedAllLabelsTheSame(self):
   """Rows consisting of a single repeated label collapse to one element."""
   # Three identical rows of five 1s; only the per-row lengths differ.
   labels = [[1] * 5 for _ in range(3)]
   seq_length = [4, 5, 1]
   collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
       labels=labels, seq_length=seq_length)
   # Regardless of seq_length, a run of one label shrinks to length 1, so
   # the expected output is a single column.
   self.assertAllEqual(new_seq_lengths, [1] * 3)
   self.assertAllEqual(collapsed, [[1] for _ in range(3)])
コード例 #2
0
 def testCollapseRepeatedExtraPadding(self):
   """Labels wider than every seq_length collapse to the longest result."""
   # Input rows are 7 wide, but at most 5 entries per row are valid; the
   # trailing entries (e.g. the 1, 2 in the middle row) must be ignored.
   padded_labels = [
       [1, 3, 3, 3, 0, 0, 0],
       [1, 4, 4, 4, 0, 1, 2],
       [4, 2, 2, 9, 4, 0, 0],
   ]
   valid_lengths = [4, 5, 5]
   # Expected: each collapsed valid prefix, with positions past the row's
   # new length set to 0; the width is the longest collapsed row (4).
   expected = [
       [1, 3, 0, 0],
       [1, 4, 0, 0],
       [4, 2, 9, 4],
   ]
   collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
       labels=padded_labels, seq_length=valid_lengths)
   self.assertAllEqual(new_seq_lengths, [2, 3, 4])
   self.assertAllEqual(collapsed, expected)
コード例 #3
0
 def testCollapseRepeatedFrontRepeats(self):
   """Repeats at the start of a row collapse, honoring each row's length."""
   template = [1, 1, 1, 2, 2]
   labels = [list(template) for _ in range(3)]
   # Lengths 5 and 4 both reach into the trailing run of 2s; length 3 stops
   # inside the leading run of 1s.
   seq_length = [5, 4, 3]
   collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
       labels=labels, seq_length=seq_length)
   self.assertAllEqual(new_seq_lengths, [2, 2, 1])
   # The shorter last row is filled with 0 up to the width of the others.
   self.assertAllEqual(collapsed, [[1, 2], [1, 2], [1, 0]])
コード例 #4
0
 def testCollapseRepeatedAllLabelsTheSame(self):
   """Rows consisting of a single repeated label collapse to one element."""
   # Three identical rows of five 1s; only the per-row lengths differ.
   labels = [[1] * 5 for _ in range(3)]
   seq_length = [4, 5, 1]
   collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
       labels=labels, seq_length=seq_length)
   # Regardless of seq_length, a run of one label shrinks to length 1, so
   # the expected output is a single column.
   self.assertAllEqual(new_seq_lengths, [1] * 3)
   self.assertAllEqual(collapsed, [[1] for _ in range(3)])
コード例 #5
0
 def testCollapseRepeatedFrontRepeats(self):
   """Repeats at the start of a row collapse, honoring each row's length."""
   template = [1, 1, 1, 2, 2]
   labels = [list(template) for _ in range(3)]
   # Lengths 5 and 4 both reach into the trailing run of 2s; length 3 stops
   # inside the leading run of 1s.
   seq_length = [5, 4, 3]
   collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
       labels=labels, seq_length=seq_length)
   self.assertAllEqual(new_seq_lengths, [2, 2, 1])
   # The shorter last row is filled with 0 up to the width of the others.
   self.assertAllEqual(collapsed, [[1, 2], [1, 2], [1, 0]])
コード例 #6
0
 def testCollapseRepeatedExtraPadding(self):
   """Labels wider than every seq_length collapse to the longest result."""
   # Input rows are 7 wide, but at most 5 entries per row are valid; the
   # trailing entries (e.g. the 1, 2 in the middle row) must be ignored.
   padded_labels = [
       [1, 3, 3, 3, 0, 0, 0],
       [1, 4, 4, 4, 0, 1, 2],
       [4, 2, 2, 9, 4, 0, 0],
   ]
   valid_lengths = [4, 5, 5]
   # Expected: each collapsed valid prefix, with positions past the row's
   # new length set to 0; the width is the longest collapsed row (4).
   expected = [
       [1, 3, 0, 0],
       [1, 4, 0, 0],
       [4, 2, 9, 4],
   ]
   collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
       labels=padded_labels, seq_length=valid_lengths)
   self.assertAllEqual(new_seq_lengths, [2, 3, 4])
   self.assertAllEqual(collapsed, expected)
コード例 #7
0
 def testCollapseRepeatedPreservesDtypes(self):
   """collapse_repeated keeps int64 dtypes on both of its outputs."""
   int64_labels = constant_op.constant(
       [[1, 3, 3, 3, 0],
        [1, 4, 4, 4, 0],
        [4, 2, 2, 9, 4]],
       dtype=dtypes.int64)
   int64_lengths = constant_op.constant([4, 5, 5], dtype=dtypes.int64)
   collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
       labels=int64_labels, seq_length=int64_lengths)
   # Dtype checks come first (lengths, then labels), then the values.
   for output in (new_seq_lengths, collapsed):
     self.assertEqual(output.dtype, dtypes.int64)
   self.assertAllEqual(new_seq_lengths, [2, 3, 4])
   # Valid prefixes collapse to [1, 3], [1, 4, 0] and [4, 2, 9, 4]; the
   # shorter rows are filled with 0 up to the widest row.
   self.assertAllEqual(
       collapsed,
       [[1, 3, 0, 0],
        [1, 4, 0, 0],
        [4, 2, 9, 4]])
コード例 #8
0
 def testCollapseRepeatedPreservesDtypes(self):
   """collapse_repeated keeps int64 dtypes on both of its outputs."""
   int64_labels = constant_op.constant(
       [[1, 3, 3, 3, 0],
        [1, 4, 4, 4, 0],
        [4, 2, 2, 9, 4]],
       dtype=dtypes.int64)
   int64_lengths = constant_op.constant([4, 5, 5], dtype=dtypes.int64)
   collapsed, new_seq_lengths = ctc_ops.collapse_repeated(
       labels=int64_labels, seq_length=int64_lengths)
   # Dtype checks come first (lengths, then labels), then the values.
   for output in (new_seq_lengths, collapsed):
     self.assertEqual(output.dtype, dtypes.int64)
   self.assertAllEqual(new_seq_lengths, [2, 3, 4])
   # Valid prefixes collapse to [1, 3], [1, 4, 0] and [4, 2, 9, 4]; the
   # shorter rows are filled with 0 up to the widest row.
   self.assertAllEqual(
       collapsed,
       [[1, 3, 0, 0],
        [1, 4, 0, 0],
        [4, 2, 9, 4]])