class BeamCTCDecoder(Decoder):
    """CTC beam-search decoder that wraps ``pytorch_ctc.CTCBeamDecoder``.

    The scorer is supplied by the caller. Decoder output is turned into
    strings with ``convert_to_strings`` / ``process_strings``, which this
    class expects to find on itself (they are not defined here).
    """

    def __init__(self, labels, scorer, beam_width=20, top_paths=1, blank_index=0, space_index=28):
        """Build the underlying beam decoder.

        Args:
            labels: Label set forwarded to the base class and to the
                ``pytorch_ctc`` decoder.
            scorer: Scorer object handed to ``CTCBeamDecoder`` as-is.
            beam_width: Beam width forwarded to ``CTCBeamDecoder``.
            top_paths: Number of paths requested from ``CTCBeamDecoder``;
                also the number of paths ``decode_multiple_paths`` returns.
            blank_index: Index of the blank label.
            space_index: Index of the space label.

        Raises:
            ImportError: If the ``pytorch_ctc`` package is not installed.
        """
        super(BeamCTCDecoder, self).__init__(labels, blank_index=blank_index, space_index=space_index)
        self._beam_width = beam_width
        self._top_n = top_paths
        try:
            import pytorch_ctc
        except ImportError:
            raise ImportError("BeamCTCDecoder requires pytorch_ctc package.")

        # Construct the decoder from the module imported under the guard
        # above, so a missing package always surfaces as the ImportError
        # raised there rather than as a failure on an unrelated name.
        self._decoder = pytorch_ctc.CTCBeamDecoder(scorer, labels, top_paths=top_paths, beam_width=beam_width,
                                                   blank_index=blank_index, space_index=space_index,
                                                   merge_repeated=False)

    def _run_decoder(self, probs, sizes):
        """Move the inputs to CPU, run the beam decoder, return (paths, lengths)."""
        sizes = sizes.cpu() if sizes is not None else None
        out, _conf, seq_len = self._decoder.decode(probs.cpu(), sizes)
        return out, seq_len

    def decode(self, probs, sizes=None):
        """Decode ``probs`` and return the processed strings for path 0 only.

        Args:
            probs: Tensor of network outputs; it is moved to CPU before
                being passed to the decoder.
            sizes: Optional tensor of sequence sizes (moved to CPU), or None.

        Returns:
            Whatever ``process_strings`` returns for the first entry of the
            decoder output (presumably the best-scoring path — confirm
            against pytorch_ctc).
        """
        out, seq_len = self._run_decoder(probs, sizes)
        strings = self.convert_to_strings(out[0], sizes=seq_len[0])
        return self.process_strings(strings)

    def decode_multiple_paths(self, probs, sizes=None):
        """Decode ``probs`` and return processed strings for every requested path.

        Args:
            probs: Tensor of network outputs; moved to CPU before decoding.
            sizes: Optional tensor of sequence sizes (moved to CPU), or None.

        Returns:
            A list with ``top_paths`` entries; entry ``i`` is the result of
            ``process_strings`` for path ``i`` of the decoder output.
        """
        out, seq_len = self._run_decoder(probs, sizes)
        strings_list = [self.convert_to_strings(out[i], sizes=seq_len[i]) for i in range(self._top_n)]
        return [self.process_strings(s) for s in strings_list]
class BeamCTCDecoder(Decoder):
    """CTC beam-search decoder with an optional KenLM language-model scorer.

    When ``lm_path`` is given, a ``KenLMScorer`` is created and configured
    with the supplied weights; otherwise a plain ``Scorer`` is used.
    """

    def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_index=28, lm_path=None,
                 trie_path=None, lm_alpha=None, lm_beta1=None, lm_beta2=None):
        """Create the scorer and the ``pytorch_ctc`` beam decoder.

        Args:
            labels: Label set forwarded to the base class, the KenLM scorer
                (when used) and the beam decoder.
            beam_width: Beam width forwarded to ``CTCBeamDecoder``.
            top_paths: Number of paths requested from ``CTCBeamDecoder``.
            blank_index: Index of the blank label.
            space_index: Index of the space label.
            lm_path: Language-model path; ``None`` selects the plain scorer.
            trie_path: Trie path passed to ``KenLMScorer``.
            lm_alpha: Value passed to ``set_lm_weight``.
            lm_beta1: Value passed to ``set_word_weight``.
            lm_beta2: Value passed to ``set_valid_word_weight``.

        Raises:
            ImportError: If the ``pytorch_ctc`` package is not installed.
        """
        super(BeamCTCDecoder, self).__init__(labels, blank_index=blank_index, space_index=space_index)
        self._beam_width = beam_width
        self._top_n = top_paths

        try:
            from pytorch_ctc import CTCBeamDecoder, Scorer, KenLMScorer
        except ImportError:
            raise ImportError("BeamCTCDecoder requires pytorch_ctc package.")

        if lm_path is None:
            language_scorer = Scorer()
        else:
            language_scorer = KenLMScorer(labels, lm_path, trie_path)
            language_scorer.set_lm_weight(lm_alpha)
            language_scorer.set_word_weight(lm_beta1)
            language_scorer.set_valid_word_weight(lm_beta2)

        decoder_options = dict(
            top_paths=top_paths,
            beam_width=beam_width,
            blank_index=blank_index,
            space_index=space_index,
            merge_repeated=False,
        )
        self._decoder = CTCBeamDecoder(language_scorer, labels, **decoder_options)

    def decode(self, probs, sizes=None):
        """Decode ``probs`` and return the processed strings for path 0.

        Args:
            probs: Tensor of network outputs; moved to CPU before decoding.
            sizes: Optional tensor of sequence sizes (moved to CPU), or None.

        Returns:
            The result of ``process_strings`` applied to the first entry of
            the decoder output.
        """
        if sizes is not None:
            sizes = sizes.cpu()
        paths, _scores, path_lengths = self._decoder.decode(probs.cpu(), sizes)
        # TODO: support returning multiple paths
        first_path = self.convert_to_strings(paths[0], sizes=path_lengths[0])
        return self.process_strings(first_path)
class BeamDecoder(Decoder):
    """CTC beam-search decoder whose label set is derived from ``int2char``.

    Index 0 is mapped to ``'#'``; every other index gets the decoded bytes
    of its ``int2char`` entry. Two label sets are kept: ``self.labels`` is
    handed to the beam decoder, while the fixed single-character alphabet
    ``self.label2`` is handed to the KenLM trie generator and scorer.
    """

    def __init__(self, int2char, top_paths=1, beam_width=200, blank_index=0, space_idx=28,
                 lm_path=None, trie_path=None, dict_path=None, lm_alpha=None, lm_beta1=None, lm_beta2=None):
        """Build the label tables, the scorer and the ``pytorch_ctc`` decoder.

        Args:
            int2char: Mapping from integer index to an array-like object
                exposing ``tobytes()`` (or legacy ``tostring()``) whose
                bytes decode to the label text.
            top_paths: Number of paths requested from ``CTCBeamDecoder``.
            beam_width: Beam width forwarded to ``CTCBeamDecoder``.
            blank_index: Index of the blank label.
            space_idx: Index of the space label.
            lm_path: Language-model path; ``None`` selects the plain scorer.
            trie_path: Trie path used by ``generate_lm_trie`` and ``KenLMScorer``.
            dict_path: Dictionary path passed to ``generate_lm_trie``.
            lm_alpha: Value passed to ``set_lm_weight``.
            lm_beta1: Value passed to ``set_word_weight``.
            lm_beta2: Value passed to ``set_valid_word_weight``.

        Raises:
            ImportError: If the ``pytorch_ctc`` package is not installed.
        """
        self.beam_width = beam_width
        self.top_n = top_paths
        self.labels = ['#']
        int2phone = dict()
        for digit in int2char:
            if digit != 0:
                raw = int2char[digit]
                # tostring() is deprecated/removed on newer array types;
                # prefer tobytes() and fall back only when it is missing.
                to_bytes = getattr(raw, 'tobytes', None)
                if to_bytes is None:
                    to_bytes = raw.tostring
                label = bytes.decode(to_bytes())
                self.labels.append(label)
                int2phone[digit] = label
        int2phone[0] = '#'
        super(BeamDecoder, self).__init__(int2phone, space_idx=space_idx, blank_index=blank_index)

        # Fixed single-character alphabet used only on the language-model side.
        self.label2 = '#123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLM'
        try:
            from pytorch_ctc import CTCBeamDecoder, Scorer, KenLMScorer
            import pytorch_ctc
        except ImportError:
            raise ImportError("BeamDecoder requires pytorch_ctc package.")

        if lm_path is not None:
            # NOTE(review): the trailing 0 and -1 are presumably the blank and
            # space indices expected by generate_lm_trie — confirm against
            # the pytorch_ctc API.
            pytorch_ctc.generate_lm_trie(dict_path, lm_path, trie_path, self.label2, 0, -1)
            scorer = KenLMScorer(self.label2, lm_path, trie_path)
            scorer.set_lm_weight(lm_alpha)
            scorer.set_word_weight(lm_beta1)
            scorer.set_valid_word_weight(lm_beta2)
        else:
            scorer = Scorer()
        self._decoder = CTCBeamDecoder(scorer=scorer, labels=self.labels, top_paths=top_paths,
                                       beam_width=beam_width, blank_index=blank_index,
                                       space_index=space_idx, merge_repeated=False)

    def decode(self, prob_tensor, frame_seq_len):
        """Decode ``prob_tensor`` and return the processed strings for path 0.

        Args:
            prob_tensor: Tensor of network outputs passed to the decoder.
            frame_seq_len: Sequence lengths; converted to a CPU
                ``torch.IntTensor`` before decoding.

        Returns:
            The result of ``_process_strings`` applied to the first entry of
            the decoder output.
        """
        frame_seq_len = torch.IntTensor(frame_seq_len).cpu()
        decoded, _, out_seq_len = self._decoder.decode(prob_tensor, seq_len=frame_seq_len)
        decoded = decoded[0]
        out_seq_len = out_seq_len[0]
        decoded = self._convert_to_strings(decoded, out_seq_len)
        return self._process_strings(decoded)