예제 #1
0
    def add_evaluation(self, utt_id, ref, pred, bpe_symbol=None):
        """Score one utterance and record its word-level alignment.

        Adds the token-level edit-distance counts between ``ref`` and
        ``pred`` to ``self.char_counter``, adds the word-level counts
        (after ``self.dictionary.tokens_to_sentence`` and
        ``self.word_filters``) to ``self.word_counter``, and stores the
        output of ``speech_utils.aligned_print`` in
        ``self.aligned_results[utt_id]``.

        Args:
            utt_id (str): utterance id; must not already be a key of
                ``self.aligned_results``.
            ref (str): whitespace-separated reference token sequence.
            pred (str): whitespace-separated hypothesis token sequence.
            bpe_symbol (optional): forwarded unchanged to
                ``self.dictionary.tokens_to_sentence`` for both ``ref`` and
                ``pred``.

        Raises:
            TypeError: if ``utt_id``, ``ref`` or ``pred`` is not a ``str``.
            AssertionError: if ``utt_id`` was already added. This is checked
                before any state is modified, so a rejected call leaves the
                accumulated statistics untouched.
        """
        if not isinstance(utt_id, str):
            raise TypeError('utt_id must be a string(got {})'.format(
                type(utt_id)))
        if not isinstance(ref, str):
            raise TypeError('ref must be a string (got {})'.format(type(ref)))
        if not isinstance(pred, str):
            raise TypeError('pred must be a string(got {})'.format(type(pred)))
        # Reject duplicates before mutating any state so the counters cannot
        # be double-incremented. An explicit raise (instead of ``assert``)
        # keeps the check active under ``python -O``; AssertionError is kept
        # as the exception type for backward compatibility.
        if utt_id in self.aligned_results:
            raise AssertionError(
                'Duplicated utterance id detected: {}'.format(utt_id))

        # filter out any non_lang_syms from ref and pred
        non_lang_syms = getattr(self.dictionary, 'non_lang_syms', None)
        assert non_lang_syms is None or isinstance(non_lang_syms, list)
        if non_lang_syms is not None and len(non_lang_syms) > 0:
            # build the lookup set once: O(1) membership per token
            non_lang_set = set(non_lang_syms)
            ref = ' '.join(x for x in ref.split() if x not in non_lang_set)
            pred = ' '.join(x for x in pred.split() if x not in non_lang_set)

        # char level counts
        _, _, char_counter = speech_utils.edit_distance(
            ref.strip().split(),
            pred.strip().split(),
        )

        # word level counts; note only ref passes use_unk_sym=False, pred
        # uses the dictionary's default for that argument
        ref_words = self.dictionary.tokens_to_sentence(ref,
                                                       use_unk_sym=False,
                                                       bpe_symbol=bpe_symbol)
        pred_words = self.dictionary.tokens_to_sentence(pred,
                                                        bpe_symbol=bpe_symbol)

        # filter words according to self.word_filters (support re.sub only)
        for pattern, repl in self.word_filters:
            ref_words = re.sub(pattern, repl, ref_words)
            pred_words = re.sub(pattern, repl, pred_words)

        ref_word_list, pred_word_list = ref_words.split(), pred_words.split()
        _, steps, word_counter = speech_utils.edit_distance(
            ref_word_list,
            pred_word_list,
        )
        aligned = speech_utils.aligned_print(
            ref_word_list,
            pred_word_list,
            steps,
        )

        # Commit all state together so that an exception raised anywhere
        # above leaves this object's statistics unchanged.
        self.char_counter += char_counter
        self.word_counter += word_counter
        self.aligned_results[utt_id] = aligned
예제 #2
0
파일: wer.py 프로젝트: valentinp72/espresso
    def add_evaluation(self, utt_id, ref, pred):
        """Score one utterance and record its word-level alignment.

        Adds the token-level edit-distance counts between ``ref`` and
        ``pred`` to ``self.char_counter``, adds the word-level counts
        (after ``self.dictionary.wordpiece_decode`` and
        ``self.word_filters``) to ``self.word_counter``, and stores the
        output of ``speech_utils.aligned_print`` in
        ``self.aligned_results[utt_id]``.

        Args:
            utt_id (str): utterance id; must not already be a key of
                ``self.aligned_results``.
            ref (str): whitespace-separated reference token sequence.
            pred (str): whitespace-separated hypothesis token sequence.

        Raises:
            TypeError: if ``utt_id``, ``ref`` or ``pred`` is not a ``str``.
            AssertionError: if ``utt_id`` was already added. This is checked
                before any state is modified, so a rejected call leaves the
                accumulated statistics untouched.
        """
        if not isinstance(utt_id, str):
            raise TypeError("utt_id must be a string(got {})".format(
                type(utt_id)))
        if not isinstance(ref, str):
            raise TypeError("ref must be a string (got {})".format(type(ref)))
        if not isinstance(pred, str):
            raise TypeError("pred must be a string(got {})".format(type(pred)))
        # Reject duplicates before mutating any state so the counters cannot
        # be double-incremented. An explicit raise (instead of ``assert``)
        # keeps the check active under ``python -O``; AssertionError is kept
        # as the exception type for backward compatibility.
        if utt_id in self.aligned_results:
            raise AssertionError(
                "Duplicated utterance id detected: {}".format(utt_id))

        # filter out any non_lang_syms from ref and pred
        non_lang_syms = getattr(self.dictionary, "non_lang_syms", None)
        assert non_lang_syms is None or isinstance(non_lang_syms, list)
        if non_lang_syms is not None and len(non_lang_syms) > 0:
            # build the lookup set once: O(1) membership per token
            non_lang_set = set(non_lang_syms)
            ref = " ".join(x for x in ref.split() if x not in non_lang_set)
            pred = " ".join(x for x in pred.split() if x not in non_lang_set)

        # char level counts
        _, _, char_counter = speech_utils.edit_distance(
            ref.strip().split(),
            pred.strip().split(),
        )

        # word level counts
        ref_words = self.dictionary.wordpiece_decode(ref)
        pred_words = self.dictionary.wordpiece_decode(pred)

        # filter words according to self.word_filters (support re.sub only)
        for pattern, repl in self.word_filters:
            ref_words = re.sub(pattern, repl, ref_words)
            pred_words = re.sub(pattern, repl, pred_words)

        ref_word_list, pred_word_list = ref_words.split(), pred_words.split()
        _, steps, word_counter = speech_utils.edit_distance(
            ref_word_list,
            pred_word_list,
        )
        aligned = speech_utils.aligned_print(
            ref_word_list,
            pred_word_list,
            steps,
        )

        # Commit all state together so that an exception raised anywhere
        # above leaves this object's statistics unchanged.
        self.char_counter += char_counter
        self.word_counter += word_counter
        self.aligned_results[utt_id] = aligned