Пример #1
0
 def test_collapse_repeated_with_blanks(self):
   """Equal labels separated by a blank survive as distinct tokens.

   Adjacent duplicates merge, blanks (0 by default) are dropped, and the
   remaining labels are left-aligned and zero-padded to the input width.
   """
   labels = jnp.array([
       [1, 0, 0, 2, 3],
       [1, 0, 1, 1, 2],
       [1, 0, 1, 0, 1],
   ])
   lengths = jnp.array([5, 5, 5])
   expected_labels = jnp.array([
       [1, 2, 3, 0, 0],
       [1, 1, 2, 0, 0],
       [1, 1, 1, 0, 0],
   ])

   collapsed, new_seq_lengths = ctc_objectives.collapse_and_remove_blanks(
       labels=labels, seq_length=lengths)

   self.assertArraysEqual(new_seq_lengths, jnp.array([3, 3, 3]))
   self.assertArraysEqual(collapsed, expected_labels)
Пример #2
0
 def test_collapse_repeated_all_labels_the_same(self):
   """A run of one repeated label collapses to a single token.

   The result is the same for every valid length, including a length of 1.
   """
   labels = jnp.array([
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
       [1, 1, 1, 1, 1],
   ])
   lengths = jnp.array([4, 5, 1])
   expected_labels = jnp.array([
       [1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [1, 0, 0, 0, 0],
   ])

   collapsed, new_seq_lengths = ctc_objectives.collapse_and_remove_blanks(
       labels=labels, seq_length=lengths)

   self.assertArraysEqual(new_seq_lengths, jnp.array([1, 1, 1]))
   self.assertArraysEqual(collapsed, expected_labels)
Пример #3
0
 def test_collapse_repeated(self):
   """Consecutive repeats merge while non-adjacent repeats are kept.

   In the third row, the leading and trailing 4s both survive because they
   are not adjacent; only the inner run of 2s is merged.
   """
   labels = jnp.array([
       [1, 3, 3, 3, 0],
       [1, 4, 4, 4, 0],
       [4, 2, 2, 9, 4],
   ])
   lengths = jnp.array([4, 5, 5])
   expected_labels = jnp.array([
       [1, 3, 0, 0, 0],
       [1, 4, 0, 0, 0],
       [4, 2, 9, 4, 0],
   ])

   collapsed, new_seq_lengths = ctc_objectives.collapse_and_remove_blanks(
       labels=labels, seq_length=lengths)

   self.assertArraysEqual(new_seq_lengths, jnp.array([2, 2, 4]))
   self.assertArraysEqual(collapsed, expected_labels)
Пример #4
0
 def test_first_item_is_blank(self):
   """Leading blanks are removed just like interior ones."""
   labels = jnp.array([
       [0, 0, 1, 0, 0, 2, 3],
       [0, 0, 1, 0, 1, 1, 2],
       [0, 0, 1, 0, 1, 0, 1],
   ])
   lengths = jnp.array([7, 7, 7])
   expected_labels = jnp.array([
       [1, 2, 3, 0, 0, 0, 0],
       [1, 1, 2, 0, 0, 0, 0],
       [1, 1, 1, 0, 0, 0, 0],
   ])

   collapsed, new_seq_lengths = ctc_objectives.collapse_and_remove_blanks(
       labels=labels, seq_length=lengths)

   self.assertArraysEqual(new_seq_lengths, jnp.array([3, 3, 3]))
   self.assertArraysEqual(collapsed, expected_labels)
Пример #5
0
 def test_different_blank_id(self):
   """An explicit blank_id (-1 here) is removed instead of the default.

   Entries past each row's seq_length do not contribute to the output.
   """
   labels = jnp.array([
       [1, -1, -1, 2, 3],
       [1, -1, 1, 0, 0],
       [1, -1, 1, 1, 0],
   ], dtype=jnp.int32)
   lengths = jnp.array([5, 3, 4])
   expected_labels = jnp.array([
       [1, 2, 3, 0, 0],
       [1, 1, 0, 0, 0],
       [1, 1, 0, 0, 0],
   ])

   collapsed, new_seq_lengths = ctc_objectives.collapse_and_remove_blanks(
       labels=labels, seq_length=lengths, blank_id=-1)

   self.assertArraysEqual(new_seq_lengths, jnp.array([3, 2, 2]))
   self.assertArraysEqual(collapsed, expected_labels)
Пример #6
0
 def test_collapse_repeated_preserve_dtypes(self):
   """Both outputs keep the int16 dtype of the inputs.

   Uses the same inputs and expectations as test_collapse_repeated, but in
   int16, so this test only adds the dtype checks.
   """
   small_int = jnp.int16
   labels = jnp.array([
       [1, 3, 3, 3, 0],
       [1, 4, 4, 4, 0],
       [4, 2, 2, 9, 4],
   ], dtype=small_int)
   lengths = jnp.array([4, 5, 5], dtype=small_int)
   expected_lengths = jnp.array([2, 2, 4]).astype(small_int)
   expected_labels = jnp.array([
       [1, 3, 0, 0, 0],
       [1, 4, 0, 0, 0],
       [4, 2, 9, 4, 0],
   ]).astype(small_int)

   collapsed, new_seq_lengths = ctc_objectives.collapse_and_remove_blanks(
       labels=labels, seq_length=lengths)

   self.assertEqual(new_seq_lengths.dtype, small_int)
   self.assertEqual(collapsed.dtype, small_int)
   self.assertArraysEqual(new_seq_lengths, expected_lengths)
   self.assertArraysEqual(collapsed, expected_labels)