コード例 #1
0
    def __init__(self, model_name, trans_df):
        """Load an ESPnet ASR model and initialise transcript bookkeeping.

        Args:
            model_name: Name handed to espnet_model_zoo's ``ModelDownloader``
                to fetch the model; it is also matched against the ``"name"``
                column of the downloader's ``data_frame`` to read the model's
                sampling rate from the ``"fs"`` column.
            trans_df: Transcription table. Stored unchanged as
                ``self.trans_df`` and converted with ``self._df_to_dict``.

        Raises:
            ValueError: If ``model_name`` has no row in the downloader's
                table, so the sampling rate cannot be determined.
        """
        # Imported locally rather than at module level (presumably to keep
        # these heavy dependencies optional until a model is actually built).
        from espnet2.bin.asr_inference import Speech2Text
        from espnet_model_zoo.downloader import ModelDownloader
        import jiwer

        self.model_name = model_name
        downloader = ModelDownloader()
        # download_and_unpack's result is unpacked as keyword arguments to
        # Speech2Text.
        self.asr_model = Speech2Text(**downloader.download_and_unpack(model_name))

        # Text accumulators; presumably filled in by other methods of the class.
        self.input_txt_list = []
        self.clean_txt_list = []
        self.output_txt_list = []
        self.transcriptions = []
        self.true_txt_list = []

        # Sampling rate of the model, taken from the downloader's model table.
        # Select the matching cell explicitly: calling int() on a one-element
        # pandas Series is deprecated (FutureWarning since pandas 2.1) and it
        # fails with an unhelpful TypeError when the name is absent.
        table = downloader.data_frame
        fs_values = table.loc[table["name"] == model_name, "fs"]
        if len(fs_values) == 0:
            raise ValueError(
                f"model {model_name!r} not found in the model zoo table; "
                "cannot determine its sampling rate")
        self.sample_rate = int(fs_values.iloc[0])

        self.trans_df = trans_df
        self.trans_dic = self._df_to_dict(trans_df)

        # Counters; how they are populated is defined elsewhere in the class.
        self.mix_counter = Counter()
        self.clean_counter = Counter()
        self.est_counter = Counter()

        # Text normalisation: lowercase, drop punctuation, collapse runs of
        # whitespace, trim, split into words and discard empty tokens.
        self.transformation = jiwer.Compose([
            jiwer.ToLowerCase(),
            jiwer.RemovePunctuation(),
            jiwer.RemoveMultipleSpaces(),
            jiwer.Strip(),
            jiwer.SentencesToListOfWords(),
            jiwer.RemoveEmptyStrings(),
        ])
コード例 #2
0
    def _calc_metrics(self, ground_truth, hypothesis):
        """Compute MER, WER, WIL and WIP between reference and hypothesis.

        Both inputs get the same normalisation: lowercasing, collapsing runs
        of whitespace, turning remaining whitespace characters into spaces,
        trimming, splitting on single spaces and dropping empty tokens.
        Without the trim/drop steps, leading or trailing whitespace would
        yield empty-string "words" that distort or break the jiwer measures.

        Args:
            ground_truth: Reference transcript(s), in any form jiwer accepts.
            hypothesis: Hypothesis transcript(s), in the same form.

        Returns:
            Tuple ``(mer, wer, wil, wip)``.
        """
        transformation = jiwer.Compose([
            jiwer.ToLowerCase(),
            jiwer.RemoveMultipleSpaces(),
            jiwer.RemoveWhiteSpace(replace_by_space=True),
            jiwer.Strip(),
            jiwer.SentencesToListOfWords(word_delimiter=" "),
            jiwer.RemoveEmptyStrings(),
        ])

        # A single alignment yields all four measures; calling jiwer.mer,
        # jiwer.wer, jiwer.wil and jiwer.wip separately would align the same
        # pair of texts four times.
        measures = jiwer.compute_measures(
            ground_truth,
            hypothesis,
            truth_transform=transformation,
            hypothesis_transform=transformation)

        return measures["mer"], measures["wer"], measures["wil"], measures["wip"]
コード例 #3
0
def _str_clean(input_string: str) -> list:
    """
    Normalise ``input_string`` with jiwer and split it into words.

    The text is lowercased, runs of whitespace are collapsed, remaining
    whitespace characters become plain spaces, and the result is trimmed and
    split on single spaces. Empty tokens are dropped, so surrounding
    whitespace never produces empty-string "words".

    Args:
        input_string: Raw text to normalise.

    Returns:
        The list of normalised words (empty for blank input).
    """
    transformation = jiwer.Compose([
        jiwer.ToLowerCase(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.RemoveWhiteSpace(replace_by_space=True),
        jiwer.Strip(),
        jiwer.SentencesToListOfWords(word_delimiter=" "),
        jiwer.RemoveEmptyStrings(),
    ])
    return transformation(input_string)
コード例 #4
0
def compute_perc_script_missing(original_script, transcript, language):
    '''
    Measure how much of ``original_script`` is missing from ``transcript``.

    Both texts are cleaned the same way: lowercased, Spanish accents and
    ¡/¿ removed, a few words and number words canonicalised ("tardes" and
    "noches" -> "dias", "uno".."nueve" -> "1".."9"), punctuation removed and
    split into words. ``${...}`` placeholders are removed from the script,
    and so are stopwords unless that would leave the script empty. A script
    word counts as present when its stem equals the stem of any transcript
    word.

    Args:
        original_script: Reference script text; may contain ``${variable}``
            placeholders.
        transcript: Text to check against the script.
        language: Language key passed to ``remove_stopwords`` and
            ``get_stemmer``.

    Returns:
        ``(fraction_missing, words_missing)``: the fraction of (filtered)
        script words whose stem does not occur in the transcript, and those
        words in script order. If the script has no words after cleaning,
        returns ``(0.0, [])`` instead of dividing by zero.
    '''
    cleaning = jiwer.Compose([
        # Lowercase first: the accent and word substitutions below are
        # lowercase and case-sensitive, so capitalised text was missed before.
        jiwer.ToLowerCase(),
        jiwer.SubstituteRegexes({"¡": "", "¿": "", "á": "a", "é": "e", "í": "i", "ó": "o", "ú": "u"}),
        # SubstituteWords already matches each key on word boundaries, so the
        # keys are bare words. Space-padded keys (" uno ") missed words at the
        # start/end of the text and every second word in runs like "uno dos".
        jiwer.SubstituteWords({"tardes": "dias",
                               "noches": "dias",
                               "uno": "1",
                               "dos": "2",
                               "tres": "3",
                               "cuatro": "4",
                               "cinco": "5",
                               "seis": "6",
                               "siete": "7",
                               "ocho": "8",
                               "nueve": "9"}),
        jiwer.RemovePunctuation(),
        jiwer.SentencesToListOfWords(word_delimiter=" "),
        jiwer.RemoveEmptyStrings()
    ])

    # Remove ${variable} placeholders from the script, then clean both texts.
    original_script_transformed = re.sub(r'\${.*?\}', '', original_script)
    original_script_transformed = cleaning(original_script_transformed)
    transcript_transformed = cleaning(transcript)

    # Remove stopwords from the script, unless that would remove every word.
    without_stopwords = remove_stopwords(original_script_transformed, language)
    if len(without_stopwords) != 0:
        original_script_transformed = without_stopwords

    if len(original_script_transformed) == 0:
        # Nothing left to compare (e.g. the script was only placeholders).
        return 0.0, []

    # Stem the transcript once into a set for O(1) membership checks
    # (this is stemming, not lemmatisation).
    stemmer = get_stemmer(language)
    transcript_stems = {stemmer.stem(word) for word in transcript_transformed}

    # Script words whose stem never appears in the transcript.
    words_missing = [word for word in original_script_transformed
                     if stemmer.stem(word) not in transcript_stems]

    return len(words_missing) / len(original_script_transformed), words_missing
コード例 #5
0
def sentence_wer(reference: str, prediction: str):
    """Return the word error rate of ``prediction`` against ``reference``.

    Both strings are normalised identically before scoring: lowercased,
    common English contractions expanded (e.g. "don't" -> "do not"),
    punctuation removed, whitespace normalised and trimmed, then split into
    words with empty tokens dropped.

    Order matters: ExpandCommonEnglishContractions matches lowercase text
    containing the apostrophe, so it must run after ToLowerCase and before
    RemovePunctuation (otherwise it never fires).
    """
    transformation = jiwer.Compose([
        jiwer.ToLowerCase(),
        jiwer.ExpandCommonEnglishContractions(),
        jiwer.RemovePunctuation(),
        jiwer.RemoveWhiteSpace(replace_by_space=True),
        # Expansion/punctuation removal can leave double spaces; collapse them.
        jiwer.RemoveMultipleSpaces(),
        jiwer.Strip(),
        jiwer.SentencesToListOfWords(),
        jiwer.RemoveEmptyStrings(),
    ])

    return jiwer.wer(reference.strip(),
                     prediction.strip(),
                     truth_transform=transformation,
                     hypothesis_transform=transformation)
コード例 #6
0
def metric(ref_trans, asr_trans, lang):
    """Compute the error rate of ASR output against reference transcripts.

    Args:
        ref_trans: Reference transcripts. For ``lang == "cn"`` this must be
            a sequence of strings aligned with ``asr_trans``.
        asr_trans: ASR hypotheses, with the same structure as ``ref_trans``.
        lang: ``"en"`` for word-level scoring of normalised English text, or
            ``"cn"`` for character-level scoring where only characters in
            U+4E00..U+9FA5 are kept and each one is scored as a token.

    Returns:
        The value returned by ``jiwer.wer``.

    Raises:
        ValueError: If ``lang`` is neither ``"en"`` nor ``"cn"``. (Previously
            the code did ``raise ("Args error!")``, which raises a plain
            string and therefore fails with a TypeError instead.)
    """
    if lang == "en":
        transformation = jiwer.Compose([
            jiwer.ToLowerCase(),
            # Remove punctuation while the text is still a string: done after
            # splitting, punctuation-only tokens (e.g. "-") became empty-string
            # words that RemoveEmptyStrings had already run past.
            jiwer.RemovePunctuation(),
            jiwer.RemoveWhiteSpace(replace_by_space=True),
            jiwer.RemoveMultipleSpaces(),
            jiwer.Strip(),
            jiwer.SentencesToListOfWords(word_delimiter=" "),
            jiwer.RemoveEmptyStrings(),
        ])
        wer = jiwer.wer(
            ref_trans,
            asr_trans,
            truth_transform=transformation,
            hypothesis_transform=transformation,
        )
    elif lang == "cn":
        # Everything outside U+4E00..U+9FA5 (punctuation, spaces, Latin
        # letters, digits) is dropped.
        non_hanzi = re.compile(r"[^\u4e00-\u9fa5]+")

        def to_char_tokens(sentence):
            # Space-separate the remaining characters so jiwer scores each
            # character as one token.
            return " ".join(non_hanzi.sub("", sentence))

        # Build new lists instead of overwriting the caller's lists in place.
        asr_trans = [to_char_tokens(sentence) for sentence in asr_trans]
        ref_trans = [to_char_tokens(sentence) for sentence in ref_trans]
        # NOTE(review): requires every normalised hypothesis to be distinct;
        # the reason for this check is not visible here — confirm.
        assert len(set(asr_trans)) == len(asr_trans)
        wer = jiwer.wer(ref_trans, asr_trans)
    else:
        raise ValueError(f"Args error! Unsupported lang: {lang!r}")
    return wer
コード例 #7
0
                    if (tstart - end) > 0.5:
                        srt.push(next_sub)
                        break
                    end = next_sub.end.hours * 3600 + next_sub.end.minutes * 60 + next_sub.end.seconds + next_sub.end.milliseconds / 1000

                    ground_truth = ground_truth + " " + next_sub.text_without_tags
                    hypothesis = kd.query_text(start, end)
                else:
                    break
            kd.mark_words(start, end)

            transformation = jiwer.Compose([
                jiwer.ToLowerCase(),
                jiwer.RemoveMultipleSpaces(),
                jiwer.RemoveWhiteSpace(replace_by_space=True),
                jiwer.SentencesToListOfWords(),
                jiwer.RemovePunctuation(),
                jiwer.RemoveEmptyStrings(),
                jiwer.SubstituteRegexes({r"ё": r"е"})
            ])
            gt = transformation([ground_truth])
            hp = transformation([hypothesis])

            gt, hp = replace_pairs(gt, hp)
            hp, gt = replace_pairs(hp, gt)

            wer(gt, hp)

            r = jiwer.compute_measures(
                gt,
                hp