def test_online_decoder_decoding_with_two_calls_no_lm(self):
    """Decode a two-utterance batch in two streaming chunks, without an LM.

    The first ``decode`` call feeds timesteps ``[:2]`` with end-of-stream
    flags ``False`` so each utterance's ``DecoderState`` carries its beams
    forward. The second call feeds the remaining timesteps with flags
    ``True``. The first beam (index 0) of each utterance must then match
    the expected strings in ``self.beam_search_result``.
    """
    decoder = ctcdecode.OnlineCTCBeamDecoder(
        self.vocab_list,
        beam_width=self.beam_size,
        blank_id=self.vocab_list.index("_"),
        log_probs_input=True,
        num_processes=24,
    )
    # One DecoderState per utterance in the batch.
    state1 = ctcdecode.DecoderState(decoder)
    state2 = ctcdecode.DecoderState(decoder)

    # log_probs_input=True, so the raw probabilities are moved to log space.
    probs_seq = torch.FloatTensor([self.probs_seq1, self.probs_seq2]).log()

    # First chunk: the stream is not finished yet. Its outputs are not
    # checked; only the state it leaves behind matters.
    decoder.decode(probs_seq[:, :2], [state1, state2], [False, False])
    # Second chunk: the remaining timesteps, marking end of stream.
    beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(
        probs_seq[:, 2:], [state1, state2], [True, True])
    del state1, state2

    output_str1 = self.convert_to_string(
        beam_results[0][0], self.vocab_list, out_seq_len[0][0])
    output_str2 = self.convert_to_string(
        beam_results[1][0], self.vocab_list, out_seq_len[1][0])
    self.assertEqual(output_str1, self.beam_search_result[0])
    self.assertEqual(output_str2, self.beam_search_result[1])
def test_online_decoder_decoding_no_lm(self):
    """Decode a full two-utterance batch in a single online call, no LM.

    Every utterance is flagged as end-of-stream on the first call. The
    first beam of each utterance must equal ``self.beam_search_result[0]``
    and ``self.beam_search_result[1]`` respectively.
    """
    decoder = ctcdecode.OnlineCTCBeamDecoder(
        self.vocab_list,
        beam_width=self.beam_size,
        blank_id=self.vocab_list.index("_"),
        log_probs_input=True,
        num_processes=24,
    )
    states = [ctcdecode.DecoderState(decoder), ctcdecode.DecoderState(decoder)]
    log_probs = torch.FloatTensor([self.probs_seq1, self.probs_seq2]).log()

    # One end-of-stream flag per batch entry, all set.
    end_of_stream = [True] * len(log_probs)
    beam_results, _, _, out_seq_len = decoder.decode(
        log_probs, states, end_of_stream)

    # Build both strings first, then compare them against the expected results.
    decoded = [
        self.convert_to_string(
            beam_results[utt][0], self.vocab_list, out_seq_len[utt][0])
        for utt in range(2)
    ]
    for utt, text in enumerate(decoded):
        self.assertEqual(text, self.beam_search_result[utt])
def test_online_decoder_decoding_with_a_lot_calls_no_lm_check_size(self):
    """Stream the same utterance through one state many times, then check sizes.

    The same log-probability sequence is fed 1000 times with end-of-stream
    ``False`` and once more with ``True``. After the final call, the output
    tensor's dimension 2 must be at least the largest reported sequence
    length.
    """
    decoder = ctcdecode.OnlineCTCBeamDecoder(
        self.vocab_list,
        beam_width=self.beam_size,
        blank_id=self.vocab_list.index("_"),
        log_probs_input=True,
        num_processes=24,
    )
    state1 = ctcdecode.DecoderState(decoder)
    # Batch holds a single utterance. The state list and the end-of-stream
    # flag list must therefore have exactly one entry each. The original
    # passed two flags, which did not match the batch size.
    probs_seq = torch.FloatTensor([self.probs_seq1]).log()
    for _ in range(1000):
        decoder.decode(probs_seq, [state1], [False])
    beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(
        probs_seq, [state1], [True])
    del state1
    self.assertGreaterEqual(beam_results.shape[2], out_seq_len.max())
def test_online_decoder_decoding_with_two_calls(self):
    """Decode one utterance in two streaming chunks using the test ARPA LM.

    The language model is ``test.arpa``, located next to this file. Timesteps
    ``[:2]`` are fed first with end-of-stream ``False``, then the rest with
    ``True``. The first beam must equal ``self.beam_search_result[2]``.
    """
    here = os.path.dirname(os.path.realpath(__file__))
    arpa_path = os.path.join(here, "test.arpa")
    decoder = ctcdecode.OnlineCTCBeamDecoder(
        self.vocab_list,
        beam_width=self.beam_size,
        blank_id=self.vocab_list.index("_"),
        log_probs_input=True,
        num_processes=24,
        model_path=arpa_path,
    )
    state = ctcdecode.DecoderState(decoder)
    log_probs = torch.FloatTensor([self.probs_seq2]).log()

    # (time slice, is end-of-stream) for each successive decode call.
    chunks = ((slice(None, 2), False), (slice(2, None), True))
    for time_span, is_last in chunks:
        outputs = decoder.decode(log_probs[:, time_span], [state], [is_last])
    beam_results, _, _, out_seq_len = outputs

    text = self.convert_to_string(
        beam_results[0][0], self.vocab_list, out_seq_len[0][0])
    self.assertEqual(text, self.beam_search_result[2])