def add_evaluation(self, utt_id, ref, pred, bpe_symbol=None):
    """Score one utterance and fold the result into the running statistics.

    The reference and prediction are compared twice with
    ``speech_utils.edit_distance``: once on their whitespace-separated
    tokens (accumulated into ``self.char_counter``) and once on the words
    obtained from ``self.dictionary.tokens_to_sentence`` after applying
    ``self.word_filters`` (accumulated into ``self.word_counter``). The
    word-level alignment produced by ``speech_utils.aligned_print`` is
    stored in ``self.aligned_results[utt_id]``.

    All results are computed before any state is modified, so a failure
    part-way through leaves the counters and ``aligned_results`` untouched.

    Args:
        utt_id (str): unique key for this utterance.
        ref (str): reference transcript as whitespace-separated tokens.
        pred (str): predicted transcript as whitespace-separated tokens.
        bpe_symbol: forwarded unchanged to
            ``self.dictionary.tokens_to_sentence``.

    Raises:
        TypeError: if ``utt_id``, ``ref`` or ``pred`` is not a ``str``.
        AssertionError: if ``utt_id`` was already evaluated, or if
            ``self.dictionary.non_lang_syms`` is neither ``None`` nor a
            list.
    """
    if not isinstance(utt_id, str):
        raise TypeError('utt_id must be a string(got {})'.format(type(utt_id)))
    if not isinstance(ref, str):
        raise TypeError('ref must be a string (got {})'.format(type(ref)))
    if not isinstance(pred, str):
        raise TypeError('pred must be a string(got {})'.format(type(pred)))

    # Reject duplicates before touching any counter. Raised explicitly (not
    # via ``assert``) so the check survives ``python -O``; AssertionError is
    # kept so existing callers see the same exception type and message.
    if utt_id in self.aligned_results:
        raise AssertionError(
            'Duplicated utterance id detected: {}'.format(utt_id)
        )

    # filter out any non_lang_syms from ref and pred
    non_lang_syms = getattr(self.dictionary, 'non_lang_syms', None)
    assert non_lang_syms is None or isinstance(non_lang_syms, list)
    if non_lang_syms:
        ref = ' '.join(
            x for x in ref.strip().split() if x not in non_lang_syms
        )
        pred = ' '.join(
            x for x in pred.strip().split() if x not in non_lang_syms
        )

    # char level counts
    _, _, char_counts = speech_utils.edit_distance(
        ref.strip().split(), pred.strip().split(),
    )

    # word level counts (only the prediction uses the default use_unk_sym)
    ref_words = self.dictionary.tokens_to_sentence(
        ref, use_unk_sym=False, bpe_symbol=bpe_symbol,
    )
    pred_words = self.dictionary.tokens_to_sentence(
        pred, bpe_symbol=bpe_symbol,
    )

    # filter words according to self.word_filters (support re.sub only)
    for pattern, repl in self.word_filters:
        ref_words = re.sub(pattern, repl, ref_words)
        pred_words = re.sub(pattern, repl, pred_words)

    ref_word_list, pred_word_list = ref_words.split(), pred_words.split()
    _, steps, word_counts = speech_utils.edit_distance(
        ref_word_list, pred_word_list,
    )
    aligned = speech_utils.aligned_print(
        ref_word_list, pred_word_list, steps,
    )

    # Commit only after every computation above has succeeded.
    self.char_counter += char_counts
    self.word_counter += word_counts
    self.aligned_results[utt_id] = aligned
def add_evaluation(self, utt_id, ref, pred):
    """Score one utterance and fold the result into the running statistics.

    The reference and prediction are compared twice with
    ``speech_utils.edit_distance``: once on their whitespace-separated
    tokens (accumulated into ``self.char_counter``) and once on the words
    obtained from ``self.dictionary.wordpiece_decode`` after applying
    ``self.word_filters`` (accumulated into ``self.word_counter``). The
    word-level alignment produced by ``speech_utils.aligned_print`` is
    stored in ``self.aligned_results[utt_id]``.

    All results are computed before any state is modified, so a failure
    part-way through leaves the counters and ``aligned_results`` untouched.

    Args:
        utt_id (str): unique key for this utterance.
        ref (str): reference transcript as whitespace-separated tokens.
        pred (str): predicted transcript as whitespace-separated tokens.

    Raises:
        TypeError: if ``utt_id``, ``ref`` or ``pred`` is not a ``str``.
        AssertionError: if ``utt_id`` was already evaluated, or if
            ``self.dictionary.non_lang_syms`` is neither ``None`` nor a
            list.
    """
    if not isinstance(utt_id, str):
        raise TypeError("utt_id must be a string(got {})".format(type(utt_id)))
    if not isinstance(ref, str):
        raise TypeError("ref must be a string (got {})".format(type(ref)))
    if not isinstance(pred, str):
        raise TypeError("pred must be a string(got {})".format(type(pred)))

    # Reject duplicates before touching any counter. Raised explicitly (not
    # via ``assert``) so the check survives ``python -O``; AssertionError is
    # kept so existing callers see the same exception type and message.
    if utt_id in self.aligned_results:
        raise AssertionError(
            "Duplicated utterance id detected: {}".format(utt_id)
        )

    # filter out any non_lang_syms from ref and pred
    non_lang_syms = getattr(self.dictionary, "non_lang_syms", None)
    assert non_lang_syms is None or isinstance(non_lang_syms, list)
    if non_lang_syms:
        ref = " ".join(
            x for x in ref.strip().split() if x not in non_lang_syms
        )
        pred = " ".join(
            x for x in pred.strip().split() if x not in non_lang_syms
        )

    # char level counts
    _, _, char_counts = speech_utils.edit_distance(
        ref.strip().split(), pred.strip().split(),
    )

    # word level counts
    ref_words = self.dictionary.wordpiece_decode(ref)
    pred_words = self.dictionary.wordpiece_decode(pred)

    # filter words according to self.word_filters (support re.sub only)
    for pattern, repl in self.word_filters:
        ref_words = re.sub(pattern, repl, ref_words)
        pred_words = re.sub(pattern, repl, pred_words)

    ref_word_list, pred_word_list = ref_words.split(), pred_words.split()
    _, steps, word_counts = speech_utils.edit_distance(
        ref_word_list, pred_word_list,
    )
    aligned = speech_utils.aligned_print(
        ref_word_list, pred_word_list, steps,
    )

    # Commit only after every computation above has succeeded.
    self.char_counter += char_counts
    self.word_counter += word_counts
    self.aligned_results[utt_id] = aligned