def test_collapse_repeated_with_blanks(self):
    """A blank between two equal labels keeps them as separate outputs.

    With the default blank id, the 0 entries are dropped. Adjacent equal
    labels merge, but `1, 0, 1` yields two 1s. Results are left-aligned and
    zero-padded to the input width.
    """
    labels = jnp.array([
        [1, 0, 0, 2, 3],
        [1, 0, 1, 1, 2],
        [1, 0, 1, 0, 1],
    ])
    seq_length = jnp.array([5, 5, 5])
    expected_lengths = jnp.array([3, 3, 3])
    expected_collapsed = jnp.array([
        [1, 2, 3, 0, 0],
        [1, 1, 2, 0, 0],
        [1, 1, 1, 0, 0],
    ])

    collapsed, new_seq_lengths = ctc_objectives.collapse_and_remove_blanks(
        labels=labels, seq_length=seq_length)

    self.assertArraysEqual(new_seq_lengths, expected_lengths)
    self.assertArraysEqual(collapsed, expected_collapsed)
def test_collapse_repeated_all_labels_the_same(self):
    """A run of a single repeated label collapses to one entry.

    This holds for every valid length, including seq_length == 1.
    """
    ones_row = [1, 1, 1, 1, 1]
    labels = jnp.array([ones_row, ones_row, ones_row])
    seq_length = jnp.array([4, 5, 1])
    expected_lengths = jnp.array([1, 1, 1])
    single_one_row = [1, 0, 0, 0, 0]
    expected_collapsed = jnp.array(
        [single_one_row, single_one_row, single_one_row])

    collapsed, new_seq_lengths = ctc_objectives.collapse_and_remove_blanks(
        labels=labels, seq_length=seq_length)

    self.assertArraysEqual(new_seq_lengths, expected_lengths)
    self.assertArraysEqual(collapsed, expected_collapsed)
def test_collapse_repeated(self):
    """Consecutive duplicates merge; non-adjacent repeats (4 ... 4) survive."""
    labels = jnp.array([
        [1, 3, 3, 3, 0],
        [1, 4, 4, 4, 0],
        [4, 2, 2, 9, 4],
    ])
    seq_length = jnp.array([4, 5, 5])
    expected_lengths = jnp.array([2, 2, 4])
    expected_collapsed = jnp.array([
        [1, 3, 0, 0, 0],
        [1, 4, 0, 0, 0],
        [4, 2, 9, 4, 0],
    ])

    collapsed, new_seq_lengths = ctc_objectives.collapse_and_remove_blanks(
        labels=labels, seq_length=seq_length)

    self.assertArraysEqual(new_seq_lengths, expected_lengths)
    self.assertArraysEqual(collapsed, expected_collapsed)
def test_first_item_is_blank(self):
    """Leading blanks are removed and the survivors shift to the front."""
    labels = jnp.array([
        [0, 0, 1, 0, 0, 2, 3],
        [0, 0, 1, 0, 1, 1, 2],
        [0, 0, 1, 0, 1, 0, 1],
    ])
    seq_length = jnp.array([7, 7, 7])
    expected_lengths = jnp.array([3, 3, 3])
    expected_collapsed = jnp.array([
        [1, 2, 3, 0, 0, 0, 0],
        [1, 1, 2, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0],
    ])

    collapsed, new_seq_lengths = ctc_objectives.collapse_and_remove_blanks(
        labels=labels, seq_length=seq_length)

    self.assertArraysEqual(new_seq_lengths, expected_lengths)
    self.assertArraysEqual(collapsed, expected_collapsed)
def test_different_blank_id(self):
    """With blank_id=-1, the -1 entries are the ones treated as blanks.

    Positions beyond each row's seq_length are ignored. The expected output
    is zero-padded.
    """
    labels = jnp.array(
        [
            [1, -1, -1, 2, 3],
            [1, -1, 1, 0, 0],
            [1, -1, 1, 1, 0],
        ],
        dtype=jnp.int32,
    )
    seq_length = jnp.array([5, 3, 4])
    expected_lengths = jnp.array([3, 2, 2])
    expected_collapsed = jnp.array([
        [1, 2, 3, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0],
    ])

    collapsed, new_seq_lengths = ctc_objectives.collapse_and_remove_blanks(
        labels=labels, seq_length=seq_length, blank_id=-1)

    self.assertArraysEqual(new_seq_lengths, expected_lengths)
    self.assertArraysEqual(collapsed, expected_collapsed)
def test_collapse_repeated_preserve_dtypes(self):
    """int16 inputs produce int16 outputs with the usual collapsed values."""
    int16 = jnp.int16
    labels = jnp.array(
        [
            [1, 3, 3, 3, 0],
            [1, 4, 4, 4, 0],
            [4, 2, 2, 9, 4],
        ],
        dtype=int16,
    )
    seq_length = jnp.array([4, 5, 5], dtype=int16)

    collapsed, new_seq_lengths = ctc_objectives.collapse_and_remove_blanks(
        labels=labels, seq_length=seq_length)

    # Check dtypes first, then values.
    self.assertEqual(new_seq_lengths.dtype, int16)
    self.assertEqual(collapsed.dtype, int16)

    expected_lengths = jnp.array([2, 2, 4], dtype=int16)
    expected_collapsed = jnp.array(
        [
            [1, 3, 0, 0, 0],
            [1, 4, 0, 0, 0],
            [4, 2, 9, 4, 0],
        ],
        dtype=int16,
    )
    self.assertArraysEqual(new_seq_lengths, expected_lengths)
    self.assertArraysEqual(collapsed, expected_collapsed)