def _calc_metrics(self, ground_truth, hypothesis):
    """Compute MER, WER, WIL and WIP between reference and hypothesis.

    Parameters
    ----------
    ground_truth : str or list of str
        Reference transcription(s).
    hypothesis : str or list of str
        ASR output transcription(s).

    Returns
    -------
    tuple of float
        (mer, wer, wil, wip) as computed by jiwer.
    """
    transformation = jiwer.Compose([
        jiwer.ToLowerCase(),
        jiwer.RemoveMultipleSpaces(),
        # BUG FIX: replace_by_space expects a bool; the original passed
        # " " (a string), which only happened to be truthy.
        jiwer.RemoveWhiteSpace(replace_by_space=True),
        jiwer.SentencesToListOfWords(word_delimiter=" "),
    ])
    # compute_measures aligns the pair once and yields every metric,
    # instead of re-running the alignment four times as before.
    measures = jiwer.compute_measures(
        ground_truth,
        hypothesis,
        truth_transform=transformation,
        hypothesis_transform=transformation,
    )
    return (measures["mer"], measures["wer"],
            measures["wil"], measures["wip"])
def __init__(self, model_name, trans_df):
    """Set up an ESPnet ASR model plus bookkeeping for transcription scoring.

    Args:
        model_name: ESPnet model-zoo name used to download/unpack the model.
        trans_df: DataFrame of reference transcriptions (also kept as a dict).
    """
    # Local imports keep the heavy ESPnet dependencies off module import time.
    from espnet2.bin.asr_inference import Speech2Text
    from espnet_model_zoo.downloader import ModelDownloader
    import jiwer
    self.model_name = model_name
    d = ModelDownloader()
    # Download (or reuse the cached) model and build the inference object.
    self.asr_model = Speech2Text(**d.download_and_unpack(model_name))
    self.input_txt_list = []
    self.clean_txt_list = []
    self.output_txt_list = []
    self.transcriptions = []
    self.true_txt_list = []
    # Sample rate comes from the model-zoo table's "fs" column.
    # NOTE(review): int() on a pandas selection assumes exactly one row
    # matches model_name — confirm.
    self.sample_rate = int(
        d.data_frame[d.data_frame["name"] == model_name]["fs"])
    self.trans_df = trans_df
    self.trans_dic = self._df_to_dict(trans_df)
    # Counter is expected from a module-level import (not visible here).
    self.mix_counter = Counter()
    self.clean_counter = Counter()
    self.est_counter = Counter()
    # Normalisation pipeline applied before scoring transcriptions.
    self.transformation = jiwer.Compose([
        jiwer.ToLowerCase(),
        jiwer.RemovePunctuation(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.Strip(),
        jiwer.SentencesToListOfWords(),
        jiwer.RemoveEmptyStrings(),
    ])
def _str_clean(input_string: str) -> list:
    """Normalise a sentence with jiwer.

    Lowercases, collapses whitespace, and splits into word tokens.

    Args:
        input_string: Raw sentence to clean.

    Returns:
        List of lowercased word tokens.
        (BUG FIX: the original annotation said ``-> str``, but the pipeline
        ends in ``SentencesToListOfWords`` which produces a list.)
    """
    transformation = jiwer.Compose([
        jiwer.ToLowerCase(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.RemoveWhiteSpace(replace_by_space=True),
        jiwer.SentencesToListOfWords(word_delimiter=" "),
    ])
    return transformation(input_string)
def calc_wer(ground_truth, hypothesis):
    """Word error rate between a reference and a hypothesis string.

    Both sides are lowercased, whitespace-normalised, stripped, have
    common English contractions expanded, and punctuation removed
    before scoring.
    """
    normalize = jiwer.Compose([
        jiwer.ToLowerCase(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.Strip(),
        jiwer.ExpandCommonEnglishContractions(),
        jiwer.RemovePunctuation(),
    ])
    return jiwer.wer(
        ground_truth,
        hypothesis,
        truth_transform=normalize,
        hypothesis_transform=normalize,
    )
def compute_perc_script_missing(original_script, transcript, language): ''' Check how much of original_script is missing in transcript. Clean and remove stopwords ''' # print(original_script) # print(transcript) cleaning = jiwer.Compose([ jiwer.SubstituteRegexes({"¡": "", "¿":"", "á": "a", "é": "e", "í": "i", "ó": "o","ú": "u"}), jiwer.SubstituteWords({ "tardes": "dias", "noches": "dias", " uno ": " 1 ", " dos ": " 2 ", " tres ": " 3 ", " cuatro ": " 4 ", " cinco ": " 5 ", " seis ": " 6 ", " siete ": " 7 ", " ocho ": " 8 ", " nueve ": " 9 "}), jiwer.RemovePunctuation(), jiwer.ToLowerCase(), jiwer.SentencesToListOfWords(word_delimiter=" "), jiwer.RemoveEmptyStrings() ]) #Remove anything between ${variable} from original_script original_script_transformed = re.sub(r'\${.*?\}','',original_script) # print(original_script_transformed) #Clean both original_script_transformed = cleaning(original_script_transformed) transcript_transformed = cleaning(transcript) # print(original_script_transformed) #Remove stopwords from original_script original_script_transformed_no_stopwords = remove_stopwords(original_script_transformed, language) if len(original_script_transformed_no_stopwords) != 0: #Sometimes removing stopwords removes all words from script original_script_transformed = original_script_transformed_no_stopwords #Lemmatize transcript stemmer = get_stemmer(language) transcript_transformed_stem = [stemmer.stem(word) for word in transcript_transformed] #Get words form original_script_transformed whose stem is not in transcript_transformed_stem words_missing = [word for word in original_script_transformed if stemmer.stem(word) not in transcript_transformed_stem] return len(words_missing)/len(original_script_transformed), words_missing
def sentence_wer(reference: str, prediction: str):
    """WER for a single sentence pair after aggressive normalisation.

    Both strings are pre-stripped, then whitespace-collapsed,
    de-punctuated, lowercased, contraction-expanded, and tokenised
    by the jiwer pipeline before scoring.
    """
    steps = [
        jiwer.RemoveMultipleSpaces(),
        jiwer.RemovePunctuation(),
        jiwer.Strip(),
        jiwer.ToLowerCase(),
        jiwer.ExpandCommonEnglishContractions(),
        jiwer.RemoveWhiteSpace(replace_by_space=True),
        jiwer.SentencesToListOfWords(),
        jiwer.RemoveEmptyStrings(),
    ]
    normalize = jiwer.Compose(steps)
    ref = reference.strip()
    hyp = prediction.strip()
    return jiwer.wer(
        ref,
        hyp,
        truth_transform=normalize,
        hypothesis_transform=normalize,
    )
def evaluate(testset, audio_directory):
    """Transcribe pre-rendered wav files with DeepSpeech and log the WER.

    Args:
        testset: Iterable of datapoints; each must expose datapoint['text'].
        audio_directory: Directory containing example_output_<i>.wav files,
            one per testset entry, at the model's sample rate.
    """
    # DeepSpeech 0.7.0 acoustic model + external language-model scorer.
    model = deepspeech.Model('deepspeech-0.7.0-models.pbmm')
    model.enableExternalScorer('deepspeech-0.7.0-models.scorer')
    predictions = []
    targets = []
    for i, datapoint in enumerate(testset):
        audio, rate = sf.read(
            os.path.join(audio_directory, f'example_output_{i}.wav'))
        assert rate == model.sampleRate(), 'wrong sample rate'
        # Convert float audio in [-1, 1] to 16-bit PCM for DeepSpeech.
        audio_int16 = (audio * (2**15)).astype(np.int16)
        text = model.stt(audio_int16)
        predictions.append(text)
        # unidecode folds accented characters to ASCII to match the
        # model's character set.
        target_text = unidecode(datapoint['text'])
        targets.append(target_text)
    # Normalise both sides identically before scoring.
    transformation = jiwer.Compose(
        [jiwer.RemovePunctuation(), jiwer.ToLowerCase()])
    targets = transformation(targets)
    predictions = transformation(predictions)
    logging.info(f'targets: {targets}')
    logging.info(f'predictions: {predictions}')
    logging.info(f'wer: {jiwer.wer(targets, predictions)}')
def metric(ref_trans, asr_trans, lang):
    """Compute WER between reference and ASR transcripts.

    Args:
        ref_trans: List of reference transcript strings.
        asr_trans: List of ASR hypothesis strings.  For lang == "cn" both
            lists are normalised IN PLACE (non-CJK characters removed,
            characters space-separated so jiwer scores per character).
        lang: "en" or "cn".

    Returns:
        The corpus WER (for "cn" effectively a character error rate).

    Raises:
        ValueError: If lang is neither "en" nor "cn".
    """
    if lang == "en":
        transformation = jiwer.Compose([
            jiwer.Strip(),
            jiwer.ToLowerCase(),
            jiwer.RemoveWhiteSpace(replace_by_space=True),
            jiwer.RemoveMultipleSpaces(),
            jiwer.SentencesToListOfWords(word_delimiter=" "),
            jiwer.RemoveEmptyStrings(),
            jiwer.RemovePunctuation(),
        ])
        wer = jiwer.wer(
            ref_trans,
            asr_trans,
            truth_transform=transformation,
            hypothesis_transform=transformation,
        )
    elif lang == "cn":
        # Keep only CJK unified ideographs, then space-separate the
        # characters so word-level WER becomes a character-level rate.
        del_symblos = re.compile(r"[^\u4e00-\u9fa5]+")
        for idx in range(len(asr_trans)):
            sentence = re.sub(del_symblos, "", asr_trans[idx])
            asr_trans[idx] = " ".join(list(sentence))
            sentence = re.sub(del_symblos, "", ref_trans[idx])
            ref_trans[idx] = " ".join(list(sentence))
        # NOTE(review): asserts all normalised hypotheses are unique;
        # duplicate utterances would abort here — confirm this is intended.
        asr_valid = set(asr_trans)
        assert len(asr_valid) == len(asr_trans)
        wer = jiwer.wer(ref_trans, asr_trans)
    else:
        # BUG FIX: `raise ("Args error!")` raised a TypeError because a
        # str is not an exception; raise a proper ValueError instead.
        raise ValueError("Args error!")
    return wer
def analyze():
    """Flask endpoint: compute WER/MER/WIL for a truth/hypothesis pair.

    Expects a JSON body with optional boolean flags (to_lower_case,
    strip_punctuation, strip_words, strip_multi_space, replace_whitespace),
    an optional comma-separated `t_words` list of words to drop, and the
    two strings `s_truth` and `s_hypo`.

    Returns:
        JSON with "wer", "mer" and "wil", or an error string on failure.
    """
    try:
        req_data = request.get_json()
        compose_rule_set = []
        # Strict `== True` checks kept deliberately: only JSON `true`
        # enables a rule, not arbitrary truthy payload values.
        if req_data.get('to_lower_case', False) == True:
            compose_rule_set.append(jiwer.ToLowerCase())
        if req_data.get('strip_punctuation', False) == True:
            compose_rule_set.append(jiwer.RemovePunctuation())
        if req_data.get('strip_words', False) == True:
            compose_rule_set.append(jiwer.Strip())
        if req_data.get('strip_multi_space', False) == True:
            compose_rule_set.append(jiwer.RemoveMultipleSpaces())
        word_excepts = req_data.get('t_words', '')
        if word_excepts != '':
            words = [a.strip() for a in word_excepts.split(",")]
            compose_rule_set.append(jiwer.RemoveSpecificWords(words))
        compose_rule_set.append(
            jiwer.RemoveWhiteSpace(
                replace_by_space=req_data.get('replace_whitespace', False)))
        transformation = jiwer.Compose(compose_rule_set)
        measures = jiwer.compute_measures(req_data.get('s_truth', ""),
                                          req_data.get('s_hypo', ""),
                                          truth_transform=transformation,
                                          hypothesis_transform=transformation)
        return jsonify({
            "wer": measures['wer'],
            "mer": measures['mer'],
            "wil": measures['wil']
        })
    except Exception:
        # BUG FIX: the original bare `except:` also swallowed SystemExit
        # and KeyboardInterrupt; catch Exception only.
        return jsonify("API endpoint Error")
if (tstart - end) > 0.5: srt.push(next_sub) break end = next_sub.end.hours * 3600 + next_sub.end.minutes * 60 + next_sub.end.seconds + next_sub.end.milliseconds / 1000 ground_truth = ground_truth + " " + next_sub.text_without_tags hypothesis = kd.query_text(start, end) else: break kd.mark_words(start, end) transformation = jiwer.Compose([ jiwer.ToLowerCase(), jiwer.RemoveMultipleSpaces(), jiwer.RemoveWhiteSpace(replace_by_space=True), jiwer.SentencesToListOfWords(), jiwer.RemovePunctuation(), jiwer.RemoveEmptyStrings(), jiwer.SubstituteRegexes({r"ё": r"е"}) ]) gt = transformation([ground_truth]) hp = transformation([hypothesis]) gt, hp = replace_pairs(gt, hp) hp, gt = replace_pairs(hp, gt) wer(gt, hp) r = jiwer.compute_measures( gt, hp
# jiwer.RemovePunctuation removes string.punctuation not all Unicode punctuation class RemovePunctuation(jiwer.AbstractTransform): def process_string(self, s: str): return regex.sub(r"\p{P}", "", s) # remove some differences that we don't care about for comparisons transform = jiwer.Compose([ jiwer.ToLowerCase(), RemovePunctuation(), jiwer.SubstituteRegexes( {r"\b(uh|um|ah|hi|alright|all right|well|kind of)\b": ""}), jiwer.SubstituteWords({ "one": "1", "two": "2", "three": "3", "four": "4", "five": "5", "six": "6", "seven": "7", "eight": "8", "nine": "9", "ten": "10", "plus": "+", "minus": "-", "check out": "checkout", "hard point": "hardpoint"}), jiwer.RemoveMultipleSpaces(), jiwer.Strip(), jiwer.SentencesToListOfWords(), jiwer.RemoveEmptyStrings() ]) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("json_path") parser.add_argument("--verbose", action="store_true") parser.add_argument("--cleaned", action="store_true") args = parser.parse_args()
# Score each decoded transcript against its reference line and write a CSV.
for transcript_file in transcript_files:
    # The file stem is the integer row index into the reference sheet.
    file_id = os.path.basename(transcript_file).split(".")[0]
    orig = data["Transcription "][int(file_id)]
    orig = orig.replace("@", "").replace("_", "").strip()
    if orig != "":
        # BUG FIX: the original reused the name `f` for both the id and
        # the open file handle, so audio_id collected file OBJECTS and
        # the file was never closed; use `with` and a separate name.
        with open(transcript_file, "r") as fh:
            tr = json.load(fh)
        pred = tr["transcriptions"][0]["utf_text"]
        pred = pred.replace("@", "").replace("_", "")
        # BUG FIX: the second `orig = orig.strip()` was a no-op copy/paste;
        # the prediction is what still needed stripping.
        pred = pred.strip()
        transformation = jiwer.Compose([
            jiwer.Strip(),
            jiwer.SentencesToListOfWords(),
            jiwer.RemoveEmptyStrings()
        ])
        error = wer(orig, pred,
                    truth_transform=transformation,
                    hypothesis_transform=transformation)
        wer_score.append(error)
        audio_id.append(file_id)

print("min: {}, mean: {}, max: {}".format(
    min(wer_score), sum(wer_score) / len(wer_score), max(wer_score)))
wer_result = pd.DataFrame()
wer_result["ID"] = audio_id
wer_result["WER"] = wer_score
wer_result.to_csv("WER_result.csv", index=False, header=True)
import json

import jiwer

# Shared normalisation applied to both reference and hypothesis.
transformation = jiwer.Compose([
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
    jiwer.SentencesToListOfWords(),
    jiwer.RemoveEmptyStrings(),
    jiwer.ToLowerCase(),
    jiwer.RemovePunctuation()
])


def calculate_wer(line: str) -> str:
    """Parse one JSONL record, score the ASR output against the label,
    and return the record re-serialised with a "wer" field added.

    (BUG FIX: the return annotation was `-> float`, but the function
    returns the JSON string, as callers rely on.)
    """
    json_obj = json.loads(line)
    label2 = json_obj.get('label')
    infer = json_obj['yitu_infer']
    # BUG FIX: jiwer.wer(truth, hypothesis) — the reference label must be
    # the first (truth) argument; the original passed the ASR output as
    # truth, skewing WER because insertions/deletions are asymmetric.
    wer_score = jiwer.wer(label2, infer,
                          truth_transform=transformation,
                          hypothesis_transform=transformation)
    json_obj['wer'] = wer_score
    return json.dumps(json_obj)
# NOTE(review): the statements below are the tail of a function whose
# definition starts before this excerpt; left unchanged.
    content = content[:begin] + redacted_text + content[end:]
    tokens = nlp.tokenize(content)
    pos_tags = [tag for _, tag in nlp.pos_tag(tokens)]
    output = {
        "char_length": len(content),
        "num_tokens": len(tokens),
        "token_lengths": [len(token) for token in tokens],
        "pos_tags": pos_tags,
    }
    if hypothesis_transcript is not None:
        output["wer"] = wer(content, hypothesis_transcript)
    return output


# Normalisation pipeline shared by all WER computations in this module.
JIWER_TRANSFORM = jiwer.Compose([
    jiwer.RemovePunctuation(),
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
    jiwer.SentencesToListOfWords(),
    jiwer.RemoveEmptyStrings()
])


def wer(ref_string, string):
    """WER between the speech content extracted from two transcripts.

    Args:
        ref_string: Reference transcript document.
        string: Hypothesis transcript document.

    Returns:
        jiwer word error rate over the extracted speech content.
    """
    ref_speech_content = extract_speech_content(ref_string)
    speech_content = extract_speech_content(string)
    return jiwer.wer(ref_speech_content, speech_content,
                     truth_transform=JIWER_TRANSFORM,
                     hypothesis_transform=JIWER_TRANSFORM)
# Punctuation characters (same set as string.punctuation).
punctuations = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'


def replace_punctuations_by_space(text):
    """Replace every punctuation character in `text` with a space."""
    for p in punctuations:
        text = text.replace(p, ' ')
    return text


def RemovePunctuation(replace_by_space=False):
    """Factory: jiwer's punctuation remover, or the space-replacing variant."""
    if not replace_by_space:
        return jiwer.RemovePunctuation()
    return replace_punctuations_by_space


# Normalisation pipeline shared by the WER computation below.
transformation = jiwer.Compose([
    jiwer.ToLowerCase(),
    RemovePunctuation(replace_by_space=False),
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
    jiwer.RemoveEmptyStrings(),
    jiwer.SentencesToListOfWords(),
    jiwer.RemoveWhiteSpace(replace_by_space=False),
    jiwer.RemoveEmptyStrings(),
])


def compute_avg_wer(ground_truth, hypothesis):
    """Average per-utterance WER over paired lists.

    NOTE(review): this excerpt is truncated below, before the accumulated
    sum is used; body left unchanged.
    """
    assert len(ground_truth) == len(hypothesis)
    wer_sum = 0
    for gt, h in zip(ground_truth, hypothesis):
        wer_score = jiwer.wer(
            gt, h,
            truth_transform=transformation,
            hypothesis_transform=transformation
import os
from jiwer import wer
import jiwer

# Compare OCR output lines against ground-truth lines file by file.
path_gt = './lines_test/'  # Ground truth- Original Text
path_trans = './noisy_lines_test_output/'  # Text from OCR
list_articles = os.listdir(path_gt)
e = 0
# Only collapse whitespace; keep punctuation/case so OCR errors count.
trans = jiwer.Compose([
    jiwer.RemoveMultipleSpaces(),
    jiwer.RemoveWhiteSpace(replace_by_space=False)
])
l = len(list_articles)
print('Total Number of lines are : ' + str(l))
i = 0
for filename in list_articles:
    f_gt = open(os.path.join(path_gt, filename), 'r', encoding="utf-8")
    f_tr = open(os.path.join(path_trans, filename), 'r', encoding="utf-8")
    doc_gt = f_gt.read()
    doc_tr = f_tr.read()
    # Progress report.
    # NOTE(review): `if (i % 1000)` fires on every i NOT divisible by
    # 1000 — probably meant `i % 1000 == 0`; left unchanged here
    # (comment-only edit).
    if (i % 1000):
        print(i, e / i)
    i += 1
    # Skip empty / near-empty ground-truth files.
    if (os.path.getsize(os.path.join(path_gt, filename)) <= 1
            or len(doc_gt) == 0):
        continue
    if (len(doc_tr) == 0):
def forward(self, wav):
    """Run CTC ASR inference on a waveform via the wrapped processor/model.

    Args:
        wav: Audio waveform(s) accepted by the processor; passed with
            sampling_rate=16000.

    Returns:
        List of decoded transcription strings.
    """
    input_values = self.processor(wav,
                                  return_tensors='pt',
                                  padding='longest',
                                  sampling_rate=16000).input_values
    logits = self.model(input_values.to(self.model.device)).logits
    # Greedy CTC decode: argmax over the vocabulary at each frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = self.processor.batch_decode(predicted_ids)
    return transcription


# Normalisation applied before WER scoring (upper-case side chosen to
# match the model's output character set).
_wer_trans = jiwer.Compose([
    jiwer.ToUpperCase(),
    jiwer.ExpandCommonEnglishContractions(),
    jiwer.RemovePunctuation(),
    jiwer.RemoveMultipleSpaces(),
    jiwer.Strip(),
    jiwer.SentencesToListOfWords(),
    jiwer.RemoveEmptyStrings(),
])


def _eval(batch, metrics, including='output', sample_rate=8000,
          use_pypesq=False):
    """Evaluate a batch of audio with the given metric names.

    NOTE(review): this excerpt is truncated shortly after the function
    starts; body left unchanged.
    """
    if use_pypesq:
        # pypesq computes PESQ itself; drop the generic 'pesq' metric.
        metrics = [m for m in metrics if m != 'pesq']
    has_estoi = False