Beispiel #1
0
    def test_online_decoder_decoding_with_two_calls_no_lm(self):
        """Decode a two-utterance batch split across two ``decode`` calls, without an LM.

        The first call feeds ``probs_seq[:, :2]`` with ``is_eos`` False for both
        utterances; the second call feeds ``probs_seq[:, 2:]`` with ``is_eos``
        True. The same two ``DecoderState`` objects are passed to both calls.
        The top beam of each utterance must equal the expected no-LM results
        ``self.beam_search_result[0]`` and ``[1]``.
        """
        decoder = ctcdecode.OnlineCTCBeamDecoder(
            self.vocab_list,
            beam_width=self.beam_size,
            blank_id=self.vocab_list.index("_"),
            log_probs_input=True,
            num_processes=24,
        )
        # One state per utterance; presumably these carry decoding progress
        # from the first decode() call into the second — TODO confirm.
        state1 = ctcdecode.DecoderState(decoder)
        state2 = ctcdecode.DecoderState(decoder)

        # The decoder is built with log_probs_input=True, so move the
        # probability fixtures into log-space. The slicing below assumes
        # dimension 1 is time — TODO confirm against the fixtures.
        probs_seq = torch.FloatTensor([self.probs_seq1, self.probs_seq2]).log()

        # First chunk, not end-of-stream: its outputs are intermediate and
        # not checked, so they are deliberately discarded.
        decoder.decode(probs_seq[:, :2], [state1, state2], [False, False])
        # Second chunk, end-of-stream for both utterances: these are the
        # results the assertions check.
        beam_results, _beam_scores, _timesteps, out_seq_len = decoder.decode(
            probs_seq[:, 2:], [state1, state2], [True, True])

        del state1, state2
        output_str1 = self.convert_to_string(beam_results[0][0],
                                             self.vocab_list,
                                             out_seq_len[0][0])
        output_str2 = self.convert_to_string(beam_results[1][0],
                                             self.vocab_list,
                                             out_seq_len[1][0])

        self.assertEqual(output_str1, self.beam_search_result[0])
        self.assertEqual(output_str2, self.beam_search_result[1])
Beispiel #2
0
    def test_online_decoder_decoding_no_lm(self):
        """Decode a two-utterance batch in one final ``decode`` call, without an LM.

        Every batch entry is flagged end-of-stream, so a single call produces
        the results that are checked. The top beam of each utterance must equal
        ``self.beam_search_result[0]`` and ``[1]`` respectively.
        """
        blank_index = self.vocab_list.index("_")
        decoder = ctcdecode.OnlineCTCBeamDecoder(
            self.vocab_list,
            beam_width=self.beam_size,
            blank_id=blank_index,
            log_probs_input=True,
            num_processes=24,
        )
        # One decoder state per utterance in the batch.
        states = [ctcdecode.DecoderState(decoder),
                  ctcdecode.DecoderState(decoder)]

        # log_probs_input=True above, so hand the decoder log-probabilities.
        log_probs = torch.FloatTensor([self.probs_seq1, self.probs_seq2]).log()

        # One end-of-stream flag per batch entry (len() of a tensor is its
        # first dimension).
        eos_flags = [True] * len(log_probs)

        beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(
            log_probs, states, eos_flags)

        # Convert each utterance's best beam first, then compare.
        decoded = [
            self.convert_to_string(beam_results[idx][0],
                                   self.vocab_list,
                                   out_seq_len[idx][0])
            for idx in (0, 1)
        ]
        for idx, text in enumerate(decoded):
            self.assertEqual(text, self.beam_search_result[idx])
Beispiel #3
0
    def test_online_decoder_decoding_with_a_lot_calls_no_lm_check_size(self):
        """Feed one utterance through 1000 non-final decode calls, then a final one.

        The same ``DecoderState`` is reused for every call. After the final
        (end-of-stream) call, dimension 2 of ``beam_results`` must be at least
        the largest length reported in ``out_seq_len``, i.e. the results
        tensor is wide enough for every decoded sequence.
        """
        decoder = ctcdecode.OnlineCTCBeamDecoder(
            self.vocab_list,
            beam_width=self.beam_size,
            blank_id=self.vocab_list.index("_"),
            log_probs_input=True,
            num_processes=24,
        )
        state1 = ctcdecode.DecoderState(decoder)

        # Batch of a single utterance, in log-space to match log_probs_input=True.
        probs_seq = torch.FloatTensor([self.probs_seq1]).log()

        # The is_eos list must have exactly one flag per state / batch entry.
        # The batch holds one utterance, so pass a single flag (the previous
        # two-element lists did not match the one-element batch).
        for _ in range(1000):
            decoder.decode(probs_seq, [state1], [False])

        beam_results, _beam_scores, _timesteps, out_seq_len = decoder.decode(
            probs_seq, [state1], [True])

        del state1
        self.assertGreaterEqual(beam_results.shape[2], out_seq_len.max())
Beispiel #4
0
    def test_online_decoder_decoding_with_two_calls(self):
        """Decode one utterance in two chunks using the ``test.arpa`` language model.

        The LM file is resolved next to this test file. The first call feeds
        timesteps ``[:2]`` with ``is_eos`` False, the second feeds ``[2:]`` with
        ``is_eos`` True, both using the same ``DecoderState``. The final top
        beam must equal ``self.beam_search_result[2]``.
        """
        test_dir = os.path.dirname(os.path.realpath(__file__))
        lm_path = os.path.join(test_dir, "test.arpa")
        decoder = ctcdecode.OnlineCTCBeamDecoder(
            self.vocab_list,
            beam_width=self.beam_size,
            blank_id=self.vocab_list.index("_"),
            log_probs_input=True,
            num_processes=24,
            model_path=lm_path,
        )
        state = ctcdecode.DecoderState(decoder)

        # Batch of one utterance in log-space (log_probs_input=True).
        log_probs = torch.FloatTensor([self.probs_seq2]).log()

        # Non-final chunk; its outputs are not inspected.
        _results, _scores, _steps, _lengths = decoder.decode(
            log_probs[:, :2], [state], [False])
        # Final chunk; these outputs are checked below.
        beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(
            log_probs[:, 2:], [state], [True])

        best_beam = beam_results[0][0]
        best_len = out_seq_len[0][0]
        decoded = self.convert_to_string(best_beam, self.vocab_list, best_len)
        self.assertEqual(decoded, self.beam_search_result[2])