def add_pred(df, path, decoder_level):
    if not path:
        data = df[['src_char']].reset_index(drop=True)
        data = data.fillna('')
    else:
        data = pd.read_csv(path, sep="\n", header=None, skip_blank_lines=False)
        data = data.fillna('')
    data.columns = ["prediction"]
    df = df.reset_index(drop=True)
    if decoder_level == 'char':
        df['prediction_char'] = data["prediction"]
        df["prediction"] = data["prediction"].apply(recover_space)
    else:
        df['prediction_char'] = data["prediction"].apply(replace_space)
        df["prediction"] = data["prediction"]
    errors, matches, ref_length = [], [], []
    errors_char, matches_char, ref_length_char = [], [], []
    df['entity_errors'] = 0
    for index, row in df.iterrows():
        # token
        ref_line = row['tgt_token']
        hyp_line = row['prediction']
        ref = ref_line.split()
        hyp = hyp_line.split()
        sm = SequenceMatcher(a=ref, b=hyp)
        errors.append(get_error_count(sm))
        matches.append(get_match_count(sm))
        ref_length.append(len(ref))
        # char
        ref = row['tgt_char'].split()
        hyp = row['prediction_char'].split()
        sm = SequenceMatcher(a=ref, b=hyp)
        errors_char.append(get_error_count(sm))
        matches_char.append(get_match_count(sm))
        ref_length_char.append(len(ref))
        # entity
        df.loc[index, 'entity_errors'] = 0  # sum([not clean_string(s) in hyp_line for s in row['entities_dic'].keys()])
    df['entity_count'] = df['entities_dic'].apply(len)
    df['token_errors'] = errors
    df['token_matches'] = matches
    df['token_length'] = ref_length
    df['char_errors'] = errors_char
    df['char_matches'] = matches_char
    df['char_length'] = ref_length_char
    df['sentence_count'] = 1
    df['sentence_error'] = 0
    df.loc[df['token_errors'] > 0, 'sentence_error'] = 1
    return df
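# get_error_count and get_match_count are not defined in this snippet; below is
# a minimal sketch consistent with how they are used throughout this section,
# assuming the per-span opcodes returned by edit_distance.SequenceMatcher.
def get_error_count(sm):
    """Total insertions + deletions + substitutions from the opcode list."""
    return sum(max(i2 - i1, j2 - j1)
               for tag, i1, i2, j1, j2 in sm.get_opcodes()
               if tag != 'equal')


def get_match_count(sm):
    """Total tokens covered by 'equal' opcodes."""
    return sum(i2 - i1
               for tag, i1, i2, j1, j2 in sm.get_opcodes()
               if tag == 'equal')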
def test_issue4(self):
    """
    Test for error reported here:
        https://github.com/belambert/edit-distance/issues/4
    """
    a = ['that', 'continuous', 'sanction', ':=', '(', 'flee', 'U', 'complain',
         ')', 'E', 'attendance', 'eye', '^', 'flowery', 'revelation', '^',
         'ridiculous', 'destination', '<EOS>', '<EOS>', '<EOS>', '<EOS>',
         '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>']
    b = ['continuous', ':=', '(', 'sanction', '^', 'flee', '^', 'attendance',
         '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>',
         '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>',
         '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>']
    target_opcodes = [
        ['delete', 0, 1, 0, 0], ['equal', 1, 2, 0, 1], ['delete', 2, 3, 0, 0],
        ['equal', 3, 4, 1, 2], ['equal', 4, 5, 2, 3], ['insert', 4, 4, 3, 4],
        ['insert', 4, 4, 4, 5], ['equal', 5, 6, 5, 6],
        ['replace', 6, 7, 6, 7], ['replace', 7, 8, 7, 8],
        ['replace', 8, 9, 8, 9], ['replace', 9, 10, 9, 10],
        ['replace', 10, 11, 10, 11], ['replace', 11, 12, 11, 12],
        ['replace', 12, 13, 12, 13], ['replace', 13, 14, 13, 14],
        ['replace', 14, 15, 14, 15], ['replace', 15, 16, 15, 16],
        ['replace', 16, 17, 16, 17], ['replace', 17, 18, 17, 18],
        ['equal', 18, 19, 18, 19], ['equal', 19, 20, 19, 20],
        ['equal', 20, 21, 20, 21], ['equal', 21, 22, 21, 22],
        ['equal', 22, 23, 22, 23], ['equal', 23, 24, 23, 24],
        ['equal', 24, 25, 24, 25], ['equal', 25, 26, 25, 26],
        ['equal', 26, 27, 26, 27], ['equal', 27, 28, 27, 28],
        ['equal', 28, 29, 28, 29]]
    sm = SequenceMatcher(a=a, b=b)
    self.assertEqual(sm.distance(), 16)
    self.assertEqual(sm.get_opcodes(), target_opcodes)
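# Sanity check on the asserted distance: the expected opcodes above contain
# 2 deletes, 2 inserts, and 12 replaces, so sm.distance() == 2 + 2 + 12 == 16.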
def test_issue4_simpler(self):
    """
    Test for error reported here:
        https://github.com/belambert/edit-distance/issues/4
    """
    a = ['that', 'continuous', 'sanction', ':=', '(']
    b = ['continuous', ':=', '(', 'sanction', '^']
    sm = SequenceMatcher(a=a, b=b)
    self.assertEqual(sm.distance(), 4)
    target_opcodes = [['delete', 0, 1, 0, 0],
                      ['equal', 1, 2, 0, 1],
                      ['delete', 2, 3, 0, 0],
                      ['equal', 3, 4, 1, 2],
                      ['equal', 4, 5, 2, 3],
                      ['insert', 4, 4, 3, 4],
                      ['insert', 4, 4, 4, 5]]
    self.assertEqual(sm.get_opcodes(), target_opcodes)
def test_unsupported(self):
    """Test if calling unimplemented methods actually generates an error."""
    a = ['a', 'b']
    b = ['a', 'b', 'd', 'c']
    sm = SequenceMatcher(a=a, b=b)
    with self.assertRaises(NotImplementedError):
        sm.find_longest_match(1, 2, 3, 4)
    with self.assertRaises(NotImplementedError):
        sm.get_grouped_opcodes()
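# The tests above pin down which parts of difflib's interface this
# SequenceMatcher implements. A quick standalone probe, assuming the
# edit_distance package is installed (pip install edit_distance):
import edit_distance

sm = edit_distance.SequenceMatcher(a=['a', 'b'], b=['a', 'b', 'd', 'c'])
print(sm.distance())                   # 2
print(sm.get_opcodes())                # per-token opcodes
print(list(sm.get_matching_blocks()))  # [[0, 0, 1], [1, 1, 1]]
# find_longest_match() and get_grouped_opcodes() raise NotImplementedError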
def test_sequence_matcher2(self):
    """Test the sequence matcher."""
    a = ['a', 'b']
    b = ['a', 'b', 'd', 'c']
    sm = SequenceMatcher()
    sm.set_seq1(a)
    sm.set_seq2(b)
    self.assertTrue(sm.distance() == 2)
    sm.set_seqs(b, a)
    self.assertTrue(sm.distance() == 2)
def test_sequence_matcher2(self):
    """Test the sequence matcher."""
    a = ["a", "b"]
    b = ["a", "b", "d", "c"]
    sm = SequenceMatcher()
    sm.set_seq1(a)
    sm.set_seq2(b)
    self.assertEqual(sm.distance(), 2)
    sm.set_seqs(b, a)
    self.assertEqual(sm.distance(), 2)
def test_issue13(self):
    sm = SequenceMatcher(a="abc", b="abdc")
    self.assertEqual(
        [
            ["equal", 0, 1, 0, 1],
            ["equal", 1, 2, 1, 2],
            ["insert", 2, 2, 2, 3],
            ["equal", 2, 3, 3, 4],
        ],
        sm.get_opcodes(),
    )
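# test_issue13 exercises plain strings: the matcher treats them as character
# sequences. A minimal standalone check, assuming the edit_distance package:
import edit_distance

sm = edit_distance.SequenceMatcher(a="abc", b="abdc")
print(sm.distance())  # 1: a single inserted character, matching the opcodes above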
def get_match_score(self, prediction, target, processing="base"):
    assert processing in ["base", "structural"]
    if processing == "structural":
        prediction_clean = self.clean_structural(prediction)
        target_clean = self.clean_structural(target)
        if prediction_clean == [] and target_clean == []:
            return 1.0
    else:
        prediction_clean = self.clean_base(prediction)
        target_clean = self.clean_base(target)
    sm = SequenceMatcher(a=prediction_clean, b=target_clean)
    return sm.ratio()
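# ratio() follows difflib's convention of 2 * matches / (len(a) + len(b)); the
# 2/3 assertions in the tests elsewhere in this section confirm this. A
# minimal check on hypothetical token lists:
import edit_distance

sm = edit_distance.SequenceMatcher(a=["select", "*"], b=["select", "*", ";"])
assert sm.ratio() == 2 * 2 / (2 + 3)  # two matched tokens out of five total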
def test_issue4_simpler(self):
    """Test for error reported here: https://github.com/belambert/edit-distance/issues/4"""
    a = ["that", "continuous", "sanction", ":=", "("]
    b = ["continuous", ":=", "(", "sanction", "^"]
    sm = SequenceMatcher(a=a, b=b)
    self.assertEqual(sm.distance(), 4)
    target_opcodes = [
        ["delete", 0, 1, 0, 0],
        ["equal", 1, 2, 0, 1],
        ["delete", 2, 3, 0, 0],
        ["equal", 3, 4, 1, 2],
        ["equal", 4, 5, 2, 3],
        ["insert", 5, 5, 3, 4],
        ["insert", 5, 5, 4, 5],
    ]
    self.assertEqual(sm.get_opcodes(), target_opcodes)
def test_sequence_matcher(self):
    """Test the sequence matcher."""
    a = ['a', 'b']
    b = ['a', 'b', 'd', 'c']
    sm = SequenceMatcher(a=a, b=b)
    opcodes = [['equal', 0, 1, 0, 1],
               ['equal', 1, 2, 1, 2],
               ['insert', 1, 1, 2, 3],
               ['insert', 1, 1, 3, 4]]
    self.assertTrue(sm.distance() == 2)
    self.assertTrue(sm.ratio() == 2 / 3)
    self.assertTrue(sm.quick_ratio() == 2 / 3)
    self.assertTrue(sm.real_quick_ratio() == 2 / 3)
    self.assertTrue(sm.distance() == 2)
    # This doesn't return anything, saves the value in the sm cache.
    self.assertTrue(not sm._compute_distance_fast())
    self.assertTrue(sm.get_opcodes() == opcodes)
    self.assertTrue(
        list(sm.get_matching_blocks()) == [[0, 0, 1], [1, 1, 1]])
def get_match_score(self, prediction, target, processing="base"):
    assert processing in ["base", "structural"]
    if processing == "structural":
        prediction_clean = self.clean_structural(prediction)
        target_clean = self.clean_structural(target)
        if prediction_clean == [] and target_clean == []:
            return 1.0
    else:
        prediction_clean = self.clean_base(prediction)
        target_clean = self.clean_base(target)
    sm = SequenceMatcher(a=prediction_clean, b=target_clean)
    # editdistance workaround on empty sequences
    if not prediction_clean and not target_clean:
        return 1
    return sm.ratio()
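# Why the guard above exists: with the 2 * matches / (len(a) + len(b))
# definition of ratio(), two empty sequences divide by zero. A minimal
# illustration, assuming the edit_distance package:
import edit_distance

sm = edit_distance.SequenceMatcher(a=[], b=[])
try:
    sm.ratio()
except ZeroDivisionError:
    print("empty inputs need an explicit score, e.g. the 1 returned above")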
def test_sequence_matcher(self):
    """Test the sequence matcher."""
    a = ['a', 'b']
    b = ['a', 'b', 'd', 'c']
    sm = SequenceMatcher(a=a, b=b)
    opcodes = [['equal', 0, 1, 0, 1],
               ['equal', 1, 2, 1, 2],
               ['insert', 1, 1, 2, 3],
               ['insert', 1, 1, 3, 4]]
    self.assertTrue(sm.distance() == 2)
    self.assertTrue(sm.ratio() == 2 / 3)
    self.assertTrue(sm.quick_ratio() == 2 / 3)
    self.assertTrue(sm.real_quick_ratio() == 2 / 3)
    self.assertTrue(sm.distance() == 2)
    # This doesn't return anything, saves the value in the sm cache.
    self.assertTrue(not sm._compute_distance_fast())
    self.assertTrue(sm.get_opcodes() == opcodes)
    self.assertTrue(list(sm.get_matching_blocks()) == [[0, 0, 1], [1, 1, 1]])
def align(src, tgt):
    """Corrects misalignments between the gold and predicted tokens, which
    will almost always have different lengths due to inserted, deleted, or
    substituted tokens in the predicted system output."""
    sm = SequenceMatcher(
        a=list(map(lambda x: x[0], tgt)),
        b=list(map(lambda x: x[0], src)))
    tgt_temp, src_temp = [], []
    opcodes = sm.get_opcodes()
    for tag, i1, i2, j1, j2 in opcodes:
        # If they are equal, tag both sides 'e' and keep them as they are
        if tag == 'equal':
            for i in range(i1, i2):
                tgt[i][1] = 'e'
                tgt_temp.append(tgt[i])
            for i in range(j1, j2):
                src[i][1] = 'e'
                src_temp.append(src[i])
        # For insertions and deletions, mirror the token onto the other
        # stream so that both sequences keep the same length
        elif tag == 'delete':
            for i in range(i1, i2):
                tgt[i][1] = 'd'
                tgt_temp.append(tgt[i])
            for i in range(i1, i2):
                src_temp.append(tgt[i])
        elif tag == 'insert':
            for i in range(j1, j2):
                src[i][1] = 'i'
                tgt_temp.append(src[i])
            for i in range(j1, j2):
                src_temp.append(src[i])
        # More complicated logic for a substitution
        elif tag == 'replace':
            for i in range(i1, i2):
                tgt[i][1] = 's'
            for i in range(j1, j2):
                src[i][1] = 's'
            tgt_temp += tgt[i1:i2]
            src_temp += src[j1:j2]
    src, tgt = src_temp, tgt_temp
    return src, tgt
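# Hypothetical usage of align(): each token is a mutable [text, tag] pair;
# after alignment both streams have equal length, with tags from
# {'e' (equal), 'd' (deleted), 'i' (inserted), 's' (substituted)}:
src = [['the', None], ['cat', None], ['sat', None]]
tgt = [['the', None], ['cats', None], ['sat', None], ['down', None]]
src_aligned, tgt_aligned = align(src, tgt)
for s, t in zip(src_aligned, tgt_aligned):
    print(t, s)  # e.g. ['down', 'd'] is mirrored into the source stream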
def test_sequence_matcher(self):
    """Test the sequence matcher."""
    a = ["a", "b"]
    b = ["a", "b", "d", "c"]
    sm = SequenceMatcher(a=a, b=b)
    opcodes = [
        ["equal", 0, 1, 0, 1],
        ["equal", 1, 2, 1, 2],
        ["insert", 2, 2, 2, 3],
        ["insert", 2, 2, 3, 4],
    ]
    self.assertEqual(sm.distance(), 2)
    self.assertEqual(sm.ratio(), 2 / 3)
    self.assertEqual(sm.quick_ratio(), 2 / 3)
    self.assertEqual(sm.real_quick_ratio(), 2 / 3)
    self.assertEqual(sm.distance(), 2)
    # This doesn't return anything, saves the value in the sm cache.
    self.assertTrue(not sm._compute_distance_fast())
    self.assertEqual(sm.get_opcodes(), opcodes)
    self.assertEqual(list(sm.get_matching_blocks()), [[0, 0, 1], [1, 1, 1]])
def test_issue4(self):
    """
    Test for error reported here:
        https://github.com/belambert/edit-distance/issues/4
    """
    a = [
        'that', 'continuous', 'sanction', ':=', '(', 'flee', 'U', 'complain',
        ')', 'E', 'attendance', 'eye', '^', 'flowery', 'revelation', '^',
        'ridiculous', 'destination', '<EOS>', '<EOS>', '<EOS>', '<EOS>',
        '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>'
    ]
    b = [
        'continuous', ':=', '(', 'sanction', '^', 'flee', '^', 'attendance',
        '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>',
        '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>',
        '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>'
    ]
    target_opcodes = [
        ['delete', 0, 1, 0, 0], ['equal', 1, 2, 0, 1], ['delete', 2, 3, 0, 0],
        ['equal', 3, 4, 1, 2], ['equal', 4, 5, 2, 3], ['insert', 4, 4, 3, 4],
        ['insert', 4, 4, 4, 5], ['equal', 5, 6, 5, 6],
        ['replace', 6, 7, 6, 7], ['replace', 7, 8, 7, 8],
        ['replace', 8, 9, 8, 9], ['replace', 9, 10, 9, 10],
        ['replace', 10, 11, 10, 11], ['replace', 11, 12, 11, 12],
        ['replace', 12, 13, 12, 13], ['replace', 13, 14, 13, 14],
        ['replace', 14, 15, 14, 15], ['replace', 15, 16, 15, 16],
        ['replace', 16, 17, 16, 17], ['replace', 17, 18, 17, 18],
        ['equal', 18, 19, 18, 19], ['equal', 19, 20, 19, 20],
        ['equal', 20, 21, 20, 21], ['equal', 21, 22, 21, 22],
        ['equal', 22, 23, 22, 23], ['equal', 23, 24, 23, 24],
        ['equal', 24, 25, 24, 25], ['equal', 25, 26, 25, 26],
        ['equal', 26, 27, 26, 27], ['equal', 27, 28, 27, 28],
        ['equal', 28, 29, 28, 29]]
    sm = SequenceMatcher(a=a, b=b)
    self.assertEqual(sm.distance(), 16)
    self.assertEqual(sm.get_opcodes(), target_opcodes)
def print_errors(df, n, random_state=1):
    df = df[df.token_errors > 0][['src_token', 'tgt_token', 'prediction']]
    print(len(df))
    df = df.sample(frac=n, random_state=random_state)
    for src_line, ref_line, hyp_line in zip(df['src_token'].values,
                                            df['tgt_token'].values,
                                            df['prediction'].values):
        ref = ref_line.split()
        hyp = hyp_line.split()
        sm = SequenceMatcher(a=ref, b=hyp)
        print("SRC:", src_line)
        print_diff(sm, ref, hyp)
        print()
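# print_diff here is asr_evaluation's pretty-printer; below is a rough
# stand-in with the same call shape that upper-cases tokens inside
# non-matching spans (this helper is an assumption, not the library's code):
def print_diff(sm, ref, hyp):
    ref_out, hyp_out = [], []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'equal':
            ref_out += ref[i1:i2]
            hyp_out += hyp[j1:j2]
        else:
            ref_out += [t.upper() for t in ref[i1:i2]] or ['***']
            hyp_out += [t.upper() for t in hyp[j1:j2]] or ['***']
    print('REF:', ' '.join(ref_out))
    print('HYP:', ' '.join(hyp_out))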
def char_error_rate(true_labels, pred_labels, decoded=False):
    cers = np.empty(len(true_labels))
    if not decoded:
        pred_labels = for_tf_or_th(pred_labels, pred_labels.swapaxes(0, 1))
    i = 0
    for true_label, pred_label in zip(true_labels, pred_labels):
        prediction = pred_label if decoded else argmax_decode(pred_label)
        ratio = SequenceMatcher(true_label, prediction).ratio()
        cers[i] = 1 - ratio
        i += 1
    return cers
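# Note that 1 - ratio() is a similarity-based proxy rather than the usual CER
# (edit distance over reference length). A comparison on a classic pair,
# assuming the edit_distance package:
import edit_distance

sm = edit_distance.SequenceMatcher(a="kitten", b="sitting")
print(1 - sm.ratio())                 # proxy used above: 1 - 8/13
print(sm.distance() / len("kitten"))  # conventional CER: 3/6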
def wer(ref, hyp):
    ref = ref.strip().replace("'", "").split()
    hyp = hyp.strip().replace("'", "").split()
    ref = list(map(str.lower, ref))
    hyp = list(map(str.lower, hyp))
    ref_token_count = len(ref)
    sm = SequenceMatcher(a=ref, b=hyp)
    # error_count = get_error_count(sm)
    match_count = get_match_count(sm)
    # wrr = match_count / ref_token_count
    wer = 1 - match_count / ref_token_count
    return wer, match_count, ref_token_count
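# Because this wer() is 1 - matches/ref_length, insertions in the hypothesis
# are not penalized. A quick illustration with hypothetical sentences:
w, matches, n = wer("the cat sat", "the cat sat down")
print(w, matches, n)  # 0.0 3 3: 'down' is an insertion and goes uncounted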
def _update_state(self, y_true, y_pred, mask=None):
    for i in range(len(y_true)):
        assert len(y_true[i]) == len(y_pred[i])
        # select utterance
        y_true_ = y_true[i]
        y_pred_ = y_pred[i]
        # remove padding
        y_true_ = y_true_[mask[i]]
        y_pred_ = y_pred_[mask[i]]
        # merge consequent states
        y_true_ = _merge_consequent_states(y_true_)
        y_pred_ = _merge_consequent_states(y_pred_)
        # compute edit distance
        sm = SequenceMatcher(a=y_true_, b=y_pred_)
        edit_distance = sm.distance()
        # update state
        self.edit_distance.assign_add(edit_distance)
        self.length.assign_add(len(y_true_))
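# _merge_consequent_states is not shown; a minimal sketch consistent with its
# use here, collapsing runs of identical labels into a single entry:
from itertools import groupby


def _merge_consequent_states(labels):
    return [key for key, _ in groupby(labels)]

# _merge_consequent_states([3, 3, 3, 7, 7, 3]) -> [3, 7, 3]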
def process_line_pair(ref_line, hyp_line):
    """Given a pair of strings corresponding to a reference and hypothesis,
    compute the edit distance, print if desired, and keep track of results
    in global variables."""
    # I don't believe these all need to be global. In any case, they shouldn't be.
    global error_count
    global match_count
    global ref_token_count
    ref = ref_line.split()
    hyp = hyp_line.split()
    id_ = None
    # If the files have IDs, then split the ID off from the text
    if files_have_ids:
        remove_sentence_ids(ref, hyp)
    # Create an object to get the edit distance, and then retrieve the
    # relevant counts that we need.
    sm = SequenceMatcher(a=ref, b=hyp)
    errors = get_error_count(sm)
    matches = get_match_count(sm)
    ref_length = len(ref)
    # Increment the total counts we're tracking
    error_count += errors
    match_count += matches
    ref_token_count += ref_length
    # If we're keeping track of which words get mixed up with which others, call track_confusions
    if confusions:
        track_confusions(sm, ref, hyp)
    # If we're printing instances, do it here (in roughly the align.c format)
    if print_instances_p:
        print_instances(ref, hyp, sm, id_=id_)
    # Keep track of the individual error rates, and reference lengths, so we
    # can compute average WERs by sentence length
    lengths.append(ref_length)
    if len(ref) > 0:
        error_rate = errors * 1.0 / len(ref)
    else:
        error_rate = float("inf")
    error_rates.append(error_rate)
    wer_bins[len(ref)].append(error_rate)
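# track_confusions is not shown; a sketch consistent with its use here,
# assuming module-level defaultdict(int) tables like those in the class-based
# variant later in this section:
def track_confusions(sm, ref, hyp):
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag == 'replace':
            for r, h in zip(ref[i1:i2], hyp[j1:j2]):
                substitution_table[(r, h)] += 1
        elif tag == 'delete':
            for r in ref[i1:i2]:
                deletion_table[r] += 1
        elif tag == 'insert':
            for h in hyp[j1:j2]:
                insertion_table[h] += 1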
def print_diff(s1, s2, prefix1='REF:', prefix2='HYP:', suffix1=None, suffix2=None):
    """Print a readable diff between two sentences. This is the only place
    we use anything from asr-evaluation."""
    a = s1.words
    b = s2.words
    sm = SequenceMatcher(a, b)
    eval_print_diff(sm, s1.words, s2.words,
                    prefix1=prefix1, prefix2=prefix2,
                    suffix1=suffix1, suffix2=suffix2)
def compute_levenshtein(self, pred, gt, max_levenshtein=5):
    assert len(pred) == len(gt)
    total = len(gt)
    # Maintain the count of each edit distance
    counter = Counter()
    max_chars = 0
    for ref, hyp in zip(gt, pred):
        max_chars += max(len(ref), len(hyp))
        counter[SequenceMatcher(a=ref, b=hyp).distance()] += 1
    avg_chars = max_chars / total
    # Compute accuracy for each edit distance
    accuracies = [0] * max_levenshtein
    total_dist = 0
    for levenshtein, count in counter.items():
        if levenshtein < max_levenshtein:
            accuracies[levenshtein] = count * 100 / total
        total_dist += levenshtein * count
    avg_dist = total_dist / total
    return accuracies, avg_dist, avg_dist / avg_chars
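# Hypothetical call (as a plain function, ignoring the unused self):
gt = ["hello", "world"]
pred = ["hullo", "world"]
accuracies, avg_dist, norm_dist = compute_levenshtein(None, pred, gt)
print(accuracies)  # [50.0, 50.0, 0, 0, 0]: % exact, % at distance 1, ...
print(avg_dist)    # 0.5 edits per pair on average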
# frame-by-frame at the phoneme level
y_pred_phones = states2phones(y_pred, phones, stateList)
y_true_phones = states2phones(y_true, phones, stateList)
accuracy.reset_states()
accuracy.update_state(y_true_phones, y_pred_phones)
print('Frame-by-frame accuracy at the phoneme level: {:.2f}%'.format(
    accuracy.result().numpy() * 100))
plt.figure()
plot_confusion_matrix(y_true_phones, y_pred_phones)
plt.title('Frame-by-frame confusion matrix at the phoneme level')

# PER at the state level
N = 10000  # number of frames to consider (distance computation is expensive)
y_pred_merged = merge_consequent_states(y_pred[:N])
y_true_merged = merge_consequent_states(y_true[:N])
sm = SequenceMatcher(a=y_true_merged, b=y_pred_merged)
edit_distance = sm.distance()
print('PER at the state level: {:.2f}%'.format(edit_distance / N * 100))

# PER at the phoneme level
y_pred_merged = merge_consequent_states(y_pred_phones[:N])
y_true_merged = merge_consequent_states(y_true_phones[:N])
sm = SequenceMatcher(a=y_true_merged, b=y_pred_merged)
edit_distance = sm.distance()
print('PER at the phoneme level: {:.2f}%'.format(edit_distance / N * 100))

# posteriors for first utterance
utterance = testdata[0]
x, y = prepare_matrices([utterance], K, feature_type,
def test_issue4(self):
    """Test for error reported here: https://github.com/belambert/edit-distance/issues/4"""
    a = [
        "that", "continuous", "sanction", ":=", "(", "flee", "U", "complain",
        ")", "E", "attendance", "eye", "^", "flowery", "revelation", "^",
        "ridiculous", "destination", "<EOS>", "<EOS>", "<EOS>", "<EOS>",
        "<EOS>", "<EOS>", "<EOS>", "<EOS>", "<EOS>", "<EOS>", "<EOS>",
    ]  # noqa
    b = [
        "continuous", ":=", "(", "sanction", "^", "flee", "^", "attendance",
        "<EOS>", "<EOS>", "<EOS>", "<EOS>", "<EOS>", "<EOS>", "<EOS>",
        "<EOS>", "<EOS>", "<EOS>", "<EOS>", "<EOS>", "<EOS>", "<EOS>",
        "<EOS>", "<EOS>", "<EOS>", "<EOS>", "<EOS>", "<EOS>", "<EOS>",
    ]  # noqa
    target_opcodes = [
        ["delete", 0, 1, 0, 0], ["equal", 1, 2, 0, 1], ["delete", 2, 3, 0, 0],
        ["equal", 3, 4, 1, 2], ["equal", 4, 5, 2, 3], ["insert", 5, 5, 3, 4],
        ["insert", 5, 5, 4, 5], ["equal", 5, 6, 5, 6],
        ["replace", 6, 7, 6, 7], ["replace", 7, 8, 7, 8],
        ["replace", 8, 9, 8, 9], ["replace", 9, 10, 9, 10],
        ["replace", 10, 11, 10, 11], ["replace", 11, 12, 11, 12],
        ["replace", 12, 13, 12, 13], ["replace", 13, 14, 13, 14],
        ["replace", 14, 15, 14, 15], ["replace", 15, 16, 15, 16],
        ["replace", 16, 17, 16, 17], ["replace", 17, 18, 17, 18],
        ["equal", 18, 19, 18, 19], ["equal", 19, 20, 19, 20],
        ["equal", 20, 21, 20, 21], ["equal", 21, 22, 21, 22],
        ["equal", 22, 23, 22, 23], ["equal", 23, 24, 23, 24],
        ["equal", 24, 25, 24, 25], ["equal", 25, 26, 25, 26],
        ["equal", 26, 27, 26, 27], ["equal", 27, 28, 27, 28],
        ["equal", 28, 29, 28, 29],
    ]  # noqa
    sm = SequenceMatcher(a=a, b=b)
    self.assertEqual(sm.distance(), 16)
    self.assertEqual(sm.get_opcodes(), target_opcodes)
def calculate_metrics(
        self,
        ref,
        rec,
        ignore_punct=True,
        label=None,
        print_verbosity=0,  # 0 = nothing, 1 = errors only, 2 = everything
        exclude=None,
        query_keyword=None):
    self.counter = 0
    self.insertion_table = defaultdict(int)
    self.deletion_table = defaultdict(int)
    self.substitution_table = defaultdict(int)
    table = str.maketrans({key: None for key in string.punctuation})
    error_count = 0
    match_count = 0
    ref_token_count = 0
    sent_error_count = 0
    lengths = []
    error_rates = []
    wer_bins = defaultdict(list)
    wrr = 0.0
    wer = 0.0
    ser = 0.0
    id_ = ''
    for i, (ref_line, rec_line) in enumerate(zip(ref, rec)):
        if ignore_punct:
            _ref = ref_line.translate(table).split()
            _rec = rec_line.translate(table).split()
        else:
            _ref = ref_line.split()
            _rec = rec_line.split()
        if label is not None:
            id_ = label[i]
            if exclude is not None:
                if exclude in id_:
                    continue
        if self.case_lower:
            _ref = list(map(str.lower, _ref))
            _rec = list(map(str.lower, _rec))
        if query_keyword is not None:
            if len([i for i in query_keyword if i in _ref]) == 0:
                continue
        sm = SequenceMatcher(a=_ref, b=_rec)
        errors = self.get_error_count(sm)
        matches = self.get_match_count(sm)
        ref_length = len(_ref)
        # Increment the total counts we're tracking
        error_count += errors
        match_count += matches
        ref_token_count += ref_length
        self.counter += 1
        if errors != 0:
            sent_error_count += 1
        self.track_confusions(sm, _ref, _rec)
        # If we're printing instances, do it here (in roughly the align.c format)
        if print_verbosity == 2 or (print_verbosity == 1 and errors != 0):
            self.print_all(_ref, _rec, sm, id_=id_)
        lengths.append(ref_length)
        if len(_ref) > 0:
            error_rate = errors * 1.0 / len(_ref)
        else:
            error_rate = float("inf")
        error_rates.append(error_rate)
        wer_bins[len(_ref)].append(error_rate)
    # Compute WER and WRR
    if ref_token_count > 0:
        wrr = match_count / ref_token_count
        wer = error_count / ref_token_count
    # Compute SER
    if self.counter > 0:
        ser = sent_error_count / self.counter
    print(f'\nSentence count: {self.counter}')
    print(f'WER: {wer:0.3%} ({error_count} / {ref_token_count})')
    print(f'WRR: {wrr:0.3%} ({match_count} / {ref_token_count})')
    print(f'SER: {ser:0.3%} ({sent_error_count} / {self.counter})')
    return wer, wrr, ser
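# Hypothetical invocation, assuming an evaluator object that exposes this
# method along with a case_lower attribute and the get_*/track_*/print_all
# helpers it calls:
ref = ["the cat sat on the mat", "hello world"]
rec = ["the cat sat on a mat", "hello word"]
wer, wrr, ser = evaluator.calculate_metrics(ref, rec, print_verbosity=1)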
# pred_one_layer_trans = transcribe(y_sample, prediction_one_layer)
# pred_four_layer_trans = transcribe(y_sample, prediction_four_layer)
y_transcribed = transcribe(y_sample)
pred_one_layer_trans = transcribe(prediction_one_layer)
pred_four_layer_trans = transcribe(prediction_four_layer)

fig, axs = plt.subplots(3)
axs[0].set_title("Correct output, state level merged")
axs[0].pcolormesh(y_transcribed.T)
axs[1].set_title(name + " 1 layer")
axs[1].pcolormesh(pred_one_layer_trans.T)
axs[2].set_title(name + " 4 layers")
axs[2].pcolormesh(pred_four_layer_trans.T)
plt.show()

seq1 = SequenceMatcher(y_transcribed, prediction_four_layer_merged)
distance = seq1.distance() / 324 * 100

# TODO: Then measure the Phone Error Rate (PER), that is the length-normalised
# edit distance between the sequence of states from the DNN and the correct
# transcription
# TODO: Use SequenceMatcher from edit_distance to quickly calculate PER

# Part 4: phoneme-level edit distance
# y_transcribed_merged = transcribe(y_sample_merged, y_sample_merged)
# pred_one_layer_trans_merged = transcribe(y_sample_merged, prediction_one_layer_merged)
# pred_four_layer_trans_merged = transcribe(y_sample_merged, prediction_four_layer_merged)
y_transcribed_merged = transcribe(y_sample_merged)
pred_one_layer_trans_merged = transcribe(prediction_one_layer_merged)
pred_four_layer_trans_merged = transcribe(prediction_four_layer_merged)
def process_line_pair(ref_line, hyp_line, case_insensitive=False, remove_empty_refs=False):
    """Given a pair of strings corresponding to a reference and hypothesis,
    compute the edit distance, print if desired, and keep track of results
    in global variables. Return true if the pair was counted, false if the
    pair was not counted due to an empty reference string."""
    # I don't believe these all need to be global. In any case, they shouldn't be.
    global error_count
    global match_count
    global ref_token_count
    global sent_error_count
    # Split into tokens by whitespace
    ref = ref_line.split()
    hyp = hyp_line.split()
    id_ = None
    # If the files have IDs, then split the ID off from the text
    if files_head_ids:
        id_ = ref[0]
        ref, hyp = remove_head_id(ref, hyp)
    elif files_tail_ids:
        id_ = ref[-1]
        ref, hyp = remove_tail_id(ref, hyp)
    if case_insensitive:
        ref = list(map(str.lower, ref))
        hyp = list(map(str.lower, hyp))
    if remove_empty_refs and len(ref) == 0:
        return False
    # Create an object to get the edit distance, and then retrieve the
    # relevant counts that we need.
    sm = SequenceMatcher(a=ref, b=hyp)
    errors = get_error_count(sm)
    matches = get_match_count(sm)
    ref_length = len(ref)
    # Increment the total counts we're tracking
    error_count += errors
    match_count += matches
    ref_token_count += ref_length
    if errors != 0:
        sent_error_count += 1
    # If we're keeping track of which words get mixed up with which others, call track_confusions
    if confusions:
        track_confusions(sm, ref, hyp)
    # If we're printing instances, do it here (in roughly the align.c format)
    if print_instances_p or (print_errors_p and errors != 0):
        print_instances(ref, hyp, sm, id_=id_)
    # Keep track of the individual error rates, and reference lengths, so we
    # can compute average WERs by sentence length
    lengths.append(ref_length)
    if len(ref) > 0:
        error_rate = errors * 1.0 / len(ref)
    else:
        error_rate = float("inf")
    error_rates.append(error_rate)
    wer_bins[len(ref)].append(error_rate)
    return True
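# A hypothetical driver matching the global-counter design of
# process_line_pair (the file names below are placeholders):
with open("ref.txt") as f_ref, open("hyp.txt") as f_hyp:
    for ref_line, hyp_line in zip(f_ref, f_hyp):
        process_line_pair(ref_line, hyp_line, case_insensitive=True)
print("WER: {:.3%}".format(error_count / ref_token_count))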