Example #1
0
    def __init__(self, int2char, top_paths = 1, beam_width = 200, blank_index = 0, space_idx = 28,
                    lm_path=None, trie_path=None, dict_path=None, lm_alpha=None, lm_beta1=None, lm_beta2=None):
        """Set up a pytorch_ctc beam-search decoder over a custom symbol table.

        Args:
            int2char: mapping from integer label index to an encoded symbol
                object exposing ``tobytes()`` (or the legacy ``tostring()``);
                its bytes are decoded to build the label strings. Index 0 is
                rendered as ``'#'`` and is not appended to ``self.labels``
                beyond the leading ``'#'``.
            top_paths: number of best paths requested from CTCBeamDecoder.
            beam_width: beam size passed to CTCBeamDecoder.
            blank_index: blank index passed to the base class and decoder.
            space_idx: space index passed to the base class and decoder.
            lm_path: KenLM model path; when None a plain ``Scorer`` is used.
            trie_path: trie path passed to ``generate_lm_trie``/``KenLMScorer``.
            dict_path: dictionary path passed to ``generate_lm_trie``.
            lm_alpha: value passed to ``scorer.set_lm_weight``.
            lm_beta1: value passed to ``scorer.set_word_weight``.
            lm_beta2: value passed to ``scorer.set_valid_word_weight``.

        Raises:
            ImportError: if the ``pytorch_ctc`` package cannot be imported.
        """
        self.beam_width = beam_width
        self.top_n = top_paths
        self.labels = ['#']
        int2phone = dict()
        for digit in int2char:
            if digit == 0:
                # Index 0 always maps to '#'. It is assigned here (not left to a
                # stale/undefined `label`) so the key keeps its iteration position.
                int2phone[digit] = '#'
                continue
            encoded = int2char[digit]
            # tostring() is deprecated in NumPy and was removed from array.array in
            # Python 3.9; tobytes() returns the same bytes where available.
            raw = encoded.tobytes() if hasattr(encoded, 'tobytes') else encoded.tostring()
            label = bytes.decode(raw)
            self.labels.append(label)
            int2phone[digit] = label
        # Guarantees an entry for 0 even when int2char has no key 0.
        int2phone[0] = '#'
        super(BeamDecoder, self).__init__(int2phone, space_idx=space_idx, blank_index=blank_index)
        # Stand-in symbol alphabet used for the LM trie and KenLMScorer instead of
        # self.labels. NOTE(review): presumably must match how the trie/LM were
        # built from dict_path — confirm.
        self.label2 = '#123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLM'

        try:
            from pytorch_ctc import CTCBeamDecoder, Scorer, KenLMScorer
            import pytorch_ctc
        except ImportError:
            raise ImportError("BeamCTCDecoder requires pytorch_ctc package.")

        if lm_path is not None:
            pytorch_ctc.generate_lm_trie(dict_path, lm_path, trie_path, self.label2, 0, -1)
            scorer = KenLMScorer(self.label2, lm_path, trie_path)
            scorer.set_lm_weight(lm_alpha)
            scorer.set_word_weight(lm_beta1)
            scorer.set_valid_word_weight(lm_beta2)
        else:
            scorer = Scorer()
        self._decoder = CTCBeamDecoder(scorer = scorer, labels = self.labels, top_paths = top_paths, beam_width = beam_width, blank_index = blank_index, space_index = space_idx, merge_repeated=False)
Example #2
0
    def __init__(self,
                 labels,
                 beam_width=20,
                 top_paths=1,
                 blank_index=0,
                 space_index=28,
                 lm_path=None,
                 trie_path=None,
                 lm_alpha=None,
                 lm_beta1=None,
                 lm_beta2=None):
        """Create a beam-search CTC decoder, optionally rescored by KenLM.

        Args:
            labels: symbol table forwarded to the base decoder, to
                ``KenLMScorer`` (when an LM is used) and to ``CTCBeamDecoder``.
            beam_width: beam size handed to ``CTCBeamDecoder``.
            top_paths: number of best paths requested from the decoder.
            blank_index: index of the CTC blank label.
            space_index: index of the word-separator label.
            lm_path: KenLM model path; ``None`` selects a plain ``Scorer``.
            trie_path: trie path handed to ``KenLMScorer`` with ``lm_path``.
            lm_alpha: value given to ``set_lm_weight``.
            lm_beta1: value given to ``set_word_weight``.
            lm_beta2: value given to ``set_valid_word_weight``.

        Raises:
            ImportError: when ``pytorch_ctc`` cannot be imported.
        """
        super(BeamCTCDecoder, self).__init__(labels, blank_index=blank_index, space_index=space_index)
        self._beam_width, self._top_n = beam_width, top_paths

        # pytorch_ctc is imported here rather than at module level.
        try:
            from pytorch_ctc import CTCBeamDecoder, Scorer, KenLMScorer
        except ImportError:
            raise ImportError("BeamCTCDecoder requires pytorch_ctc package.")

        if lm_path is None:
            # No language model configured: use the default scorer.
            scorer = Scorer()
        else:
            scorer = KenLMScorer(labels, lm_path, trie_path)
            scorer.set_lm_weight(lm_alpha)
            scorer.set_word_weight(lm_beta1)
            scorer.set_valid_word_weight(lm_beta2)

        decoder_options = dict(top_paths=top_paths,
                               beam_width=beam_width,
                               blank_index=blank_index,
                               space_index=space_index,
                               merge_repeated=False)
        self._decoder = CTCBeamDecoder(scorer, labels, **decoder_options)
Example #3
0
    def __init__(self, labels, beam_width=20, top_paths=1, blank_index=0, space_index=28, lm_path=None, trie_path=None,
                 lm_alpha=None, lm_beta1=None, lm_beta2=None):
        """Initialise the base decoder and attach a pytorch_ctc beam decoder.

        When ``lm_path`` is given, decoding is scored by a ``KenLMScorer`` built
        from ``labels``, ``lm_path`` and ``trie_path``, with ``lm_alpha``,
        ``lm_beta1`` and ``lm_beta2`` applied via ``set_lm_weight``,
        ``set_word_weight`` and ``set_valid_word_weight``. Otherwise a default
        ``Scorer`` is used. Repeated labels are not merged by the decoder.

        Raises:
            ImportError: if ``pytorch_ctc`` is not installed.
        """
        base = super(BeamCTCDecoder, self)
        base.__init__(labels, blank_index=blank_index, space_index=space_index)
        self._beam_width = beam_width
        self._top_n = top_paths

        try:
            from pytorch_ctc import CTCBeamDecoder, Scorer, KenLMScorer
        except ImportError:
            raise ImportError("BeamCTCDecoder requires pytorch_ctc package.")

        def build_scorer():
            # Plain scorer unless a KenLM model path was supplied.
            if lm_path is None:
                return Scorer()
            lm_scorer = KenLMScorer(labels, lm_path, trie_path)
            lm_scorer.set_lm_weight(lm_alpha)
            lm_scorer.set_word_weight(lm_beta1)
            lm_scorer.set_valid_word_weight(lm_beta2)
            return lm_scorer

        self._decoder = CTCBeamDecoder(
            build_scorer(),
            labels,
            top_paths=top_paths,
            beam_width=beam_width,
            blank_index=blank_index,
            space_index=space_index,
            merge_repeated=False,
        )
def gen_decoded(feat_list, model_path):
    """Decode every feature file listed in *feat_list* and write an HTK MLF.

    The acoustic model is built from module-level settings (``layers``,
    ``hidden_size``, ``num_dirs``, ``gpu_dtype``) and loaded from
    *model_path*. A decoder is chosen by the module-level ``decoder_type``.
    Features are mean/variance normalised with the statistics returned by
    ``read_mv(stat_file)``. Output goes to the module-level ``out_mlf`` path:
    a ``#!MLF!#`` header, then for each utterance a quoted ``.rec`` name, one
    decoded symbol per line, and a terminating ``.`` line.

    Args:
        feat_list: path to a text file with one feature-file path per line;
            blank lines are skipped.
        model_path: path to a state dict loadable with ``torch.load``.

    Raises:
        ValueError: if ``decoder_type`` is not 'Greedy', 'Beam' or 'Beam_LM'.
        Exception: if ``read_mv`` returns ``None`` for the mean or variance.
    """
    import os.path

    model = set_model_ctc.Layered_RNN(
        rnn_input_size=40,
        nb_layers=layers,
        rnn_hidden_size=hidden_size,
        bidirectional=(num_dirs == 2),
        batch_norm=True,
        num_classes=61)
    model = model.type(gpu_dtype)
    model.load_state_dict(torch.load(model_path))  # load model params
    model.eval()  # Put the model in test mode (the opposite of model.train(), essentially)

    # Exactly one branch must build `decoder`; fail fast on an unknown type
    # instead of hitting an unbound name after the MLF has been started.
    if decoder_type == 'Greedy':
        labels = create_mapping(mapping_file)
        decoder = ctc_decode.GreedyDecoder_test(
            labels, output='char', space_idx=-1)  # setup greedy decoder
    elif decoder_type == 'Beam':
        labels = create_mapping(mapping_file)
        scorer = Scorer()
        decoder = ctc_decode.BeamDecoder_test(
            labels,
            scorer,
            top_paths=1,
            beam_width=200,
            output='char',
            space_idx=-1)  # setup beam decoder without lm
    elif decoder_type == 'Beam_LM':
        labels_symbol = '_123456789abcde~-hij,.|{ofg?!+u}[x]@ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        labels_true = create_mapping(mapping_file)
        # need to use the fake symbols here for consistency with the trie
        scorer = KenLMScorer(labels_symbol,
                             kenlm_path,
                             trie_path,
                             blank_index=0,
                             space_index=-1)
        scorer.set_lm_weight(lm_weight)
        scorer.set_word_weight(lm_beta1)
        scorer.set_valid_word_weight(lm_beta2)
        # need to use the true timit label to convert the decoded position indexes back to phone labels
        decoder = ctc_decode.BeamDecoder_test(
            labels_true,
            scorer,
            top_paths=1,
            beam_width=200,
            output='char',
            space_idx=-1)  # setup beam decoder with lm
    else:
        raise ValueError("unknown decoder_type %r; expected 'Greedy', 'Beam' or 'Beam_LM'"
                         % (decoder_type,))

    m, v = read_mv(stat_file)
    if m is None or v is None:
        raise Exception("mean or variance vector does not exist")

    with open(feat_list) as f, open(out_mlf, 'w') as fw:
        fw.write('#!MLF!#\n')
        for line in f:
            line = line.strip()
            if not line:
                continue
            print("recognizing file %s" % line)
            # Label name is the feature path with its extension replaced by
            # .rec. splitext leaves extension-less paths intact and ignores dots
            # in directory names (rfind('.') chopped characters in both cases).
            stem = os.path.splitext(line)[0]
            fw.write('"%s.rec"\n' % stem)
            # NOTE(review): the reader returned by htk_io.fopen is never closed
            # here; confirm whether it holds an open file handle.
            htk_reader = htk_io.fopen(line)
            utt_feat = htk_reader.getall()
            utt_feat -= m  # normalize mean
            utt_feat /= (np.sqrt(v) + eps)  # normalize var
            feat_numpy = org_data(utt_feat, skip_frames=5)
            feat_tensor = torch.from_numpy(feat_numpy).type(gpu_dtype)
            # NOTE: `volatile` is ignored from PyTorch 0.4 onward, where
            # torch.no_grad() is the way to disable autograd during inference.
            x = Variable(feat_tensor, volatile=True)
            input_sizes_list = [x.size(1)]
            x = nn.utils.rnn.pack_padded_sequence(x,
                                                  input_sizes_list,
                                                  batch_first=True)
            probs = model(x, input_sizes_list)
            probs = probs.data.cpu()
            decoded = decoder.decode(probs, input_sizes_list)[0]
            fw.writelines(word + '\n' for word in decoded)
            fw.write('.\n')
            print(' '.join(decoded))