def process_predictions(
    args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
    """Score the top hypothesis against the reference, optionally dumping both.

    Args:
        args: options object; ``nbest``, ``remove_bpe`` and ``quiet`` are read.
        hypos: hypotheses for one utterance; each is a mapping with a
            ``"tokens"`` tensor and optionally a ``"words"`` list.
        sp: not used by this function.
        tgt_dict: dictionary whose ``string`` method renders token ids.
        target_tokens: reference token ids, rendered with ``tgt_dict.string``.
        res_files: ``None``, or a mapping with the keys ``"hypo.units"``,
            ``"hypo.words"``, ``"ref.units"`` and ``"ref.words"`` whose values
            are writable file objects.
        speaker: speaker tag written next to every dumped line.
        id: utterance id written next to every dumped line.

    Returns:
        ``(edit_distance, reference_word_count)`` for the first hypothesis.
        The function returns ``None`` implicitly when the n-best slice is
        empty, because the loop body never runs.
    """

    def emit(text, key):
        # Every dumped line has the form "<text> (<speaker>-<id>)".
        print("{} ({}-{})".format(text, speaker, id), file=res_files[key])

    dump_results = res_files is not None

    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
        # Prefer decoder-provided words; otherwise derive them from the units.
        hyp_text = (
            " ".join(hypo["words"])
            if "words" in hypo
            else post_process(hyp_pieces, args.remove_bpe)
        )
        if dump_results:
            emit(hyp_pieces, "hypo.units")
            emit(hyp_text, "hypo.words")

        tgt_pieces = tgt_dict.string(target_tokens)
        ref_text = post_process(tgt_pieces, args.remove_bpe)
        if dump_results:
            emit(tgt_pieces, "ref.units")
            emit(ref_text, "ref.words")

        if not args.quiet:
            logger.debug("HYPO:" + hyp_text)
            logger.debug("TARGET:" + ref_text)
            logger.debug("___________________")

        # only score top hypothesis
        hyp_tokens = hyp_text.split()
        ref_tokens = ref_text.split()
        return editdistance.eval(hyp_tokens, ref_tokens), len(ref_tokens)
def postprocess_tokens(self, task, target, hypo_tokens):
    """Render a target tensor and hypothesis tokens as post-processed strings.

    Args:
        task: object exposing ``target_dictionary``.
        target: target token tensor; padding is stripped before rendering.
        hypo_tokens: hypothesis token ids, rendered as given.

    Returns:
        ``(tgt_str, hypo_str)``, both passed through ``post_process`` in
        ``"letter"`` mode.
    """
    dictionary = task.target_dictionary

    # Reference: drop padding, move to CPU ints, then render.
    ref_ids = utils.strip_pad(target, dictionary.pad()).int().cpu()
    tgt_str = post_process(dictionary.string(ref_ids), "letter")

    # Hypothesis: rendered directly from the supplied token ids.
    hypo_str = post_process(dictionary.string(hypo_tokens), "letter")

    return tgt_str, hypo_str
def process_sentence(
    self,
    sample: Dict[str, Any],
    hypo: Dict[str, Any],
    sid: int,
    batch_id: int,
) -> Tuple[int, int]:
    """Score one hypothesis against its reference and optionally dump both.

    Args:
        sample: batch dict; ``"target_label"`` is used when present,
            otherwise ``"target"``.
        hypo: hypothesis with a ``"tokens"`` tensor and optionally ``"words"``.
        sid: sentence id written next to each dumped line.
        batch_id: row of the target tensor belonging to this sentence.

    Returns:
        ``(edit_distance, reference_word_count)`` computed over
        whitespace-split words.
    """
    speaker = None  # Speaker can't be parsed from dataset.

    label_key = "target_label" if "target_label" in sample else "target"
    ref_row = sample[label_key][batch_id, :]

    # Processes hypothesis.
    hyp_pieces = self.tgt_dict.string(hypo["tokens"].int().cpu())
    hyp_text = (
        " ".join(hypo["words"])
        if "words" in hypo
        else post_process(hyp_pieces, self.cfg.common_eval.post_process)
    )

    # Processes target.
    ref_ids = utils.strip_pad(ref_row, self.tgt_dict.pad())
    tgt_pieces = self.tgt_dict.string(ref_ids.int().cpu())
    ref_text = post_process(tgt_pieces, self.cfg.common_eval.post_process)

    if self.cfg.decoding.results_path is not None:
        print(f"{hyp_pieces} ({speaker}-{sid})", file=self.hypo_units_file)
        print(f"{hyp_text} ({speaker}-{sid})", file=self.hypo_words_file)
        print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file)
        print(f"{ref_text} ({speaker}-{sid})", file=self.ref_words_file)

    if not self.cfg.common_eval.quiet:
        logger.info(f"HYPO: {hyp_text}")
        logger.info(f"REF: {ref_text}")
        logger.info("---------------------")

    hyp_tokens = hyp_text.split()
    ref_tokens = ref_text.split()
    return editdistance.eval(hyp_tokens, ref_tokens), len(ref_tokens)
def _inference_with_wer(self, generator, sample, model):
    """Run inference on a batch and accumulate character/word error counts.

    The top hypothesis of each sentence is compared with its reference.
    Both strings are post-processed and stripped of leading/trailing ``"_"``.
    Character errors are the edit distance between the two strings; word
    errors are the edit distance between their ``"_"``-split pieces.

    Returns:
        dict with ``num_char_errors``, ``num_chars``, ``num_word_errors``
        and ``num_words``.
    """

    def decode(toks, escape_unk=True):
        text = self.target_dictionary.string(
            toks.int().cpu(),
            self.args.eval_wer_remove_bpe,
            escape_unk=escape_unk,
            extra_symbols_to_ignore={generator.eos},
        )
        return self.tokenizer.decode(text) if self.tokenizer else text

    def normalize(text):
        return post_process(text, self.args.eval_wer_remove_bpe).strip("_")

    char_errors = char_total = 0
    word_errors = word_total = 0

    gen_out = self.inference_step(generator, [model], sample, None)
    for idx, nbest in enumerate(gen_out):
        hyp_raw = decode(nbest[0]["tokens"])
        ref_raw = decode(
            utils.strip_pad(sample["target"][idx], self.target_dictionary.pad()),
            escape_unk=True,
        )
        hyp = normalize(hyp_raw)
        ref = normalize(ref_raw)

        # Character level: compare the strings themselves.
        char_errors += editdistance.eval(hyp, ref)
        char_total += len(ref)

        # Word level: "_" is the word delimiter after post-processing.
        ref_words = ref.split("_")
        word_errors += editdistance.eval(hyp.split("_"), ref_words)
        word_total += len(ref_words)

    return {
        "num_char_errors": char_errors,
        "num_chars": char_total,
        "num_word_errors": word_errors,
        "num_words": word_total,
    }
def process_predictions(
    args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
    """Optionally dump the top hypothesis and reference; report reference length.

    Args:
        args: options object; ``nbest`` and ``post_process`` are read.
        hypos: hypotheses for one utterance; each is a mapping with a
            ``"tokens"`` tensor and optionally a ``"words"`` list.
        sp: not used by this function.
        tgt_dict: dictionary whose ``string`` method renders token ids.
        target_tokens: reference token ids, rendered with ``tgt_dict.string``.
        res_files: ``None``, or a mapping with the keys ``"hypo.units"``,
            ``"hypo.words"``, ``"ref.units"`` and ``"ref.words"`` whose values
            are writable file objects.
        speaker: speaker tag written next to every dumped line.
        id: utterance id written next to every dumped line.

    Returns:
        ``(0, reference_word_count)`` for the first hypothesis. The function
        returns ``None`` implicitly when the n-best slice is empty.
    """
    line_format = "{} ({}-{})"

    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
        if "words" in hypo:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, args.post_process)

        if res_files is not None:
            for text, key in ((hyp_pieces, "hypo.units"), (hyp_words, "hypo.words")):
                print(line_format.format(text, speaker, id), file=res_files[key])

        tgt_pieces = tgt_dict.string(target_tokens)
        tgt_words = post_process(tgt_pieces, args.post_process)

        if res_files is not None:
            for text, key in ((tgt_pieces, "ref.units"), (tgt_words, "ref.words")):
                print(line_format.format(text, speaker, id), file=res_files[key])

        # only score top hypothesis
        # NOTE: edit-distance scoring is disabled in this variant, so the
        # error count is reported as 0 and only the reference length is real.
        return 0, len(tgt_words.split())
def string(
    self,
    tensor,
    bpe_symbol=None,
    escape_unk=False,
    extra_symbols_to_ignore=None,
    unk_string=None,
    include_eos=False,
    separator=" ",
):
    """Helper for converting a tensor of token indices to a string.

    Can optionally remove BPE symbols or escape <unk> words.

    Args:
        tensor: iterable of token indices. A 2-D torch tensor is rendered
            row by row and the rows are joined with newlines.
        bpe_symbol: forwarded to ``data_utils.post_process``.
        escape_unk: forwarded to ``self.unk_string`` when ``unk_string`` is
            not given.
        extra_symbols_to_ignore: optional iterable of indices to skip.
        unk_string: if not ``None``, used verbatim for unknown tokens.
        include_eos: when true, EOS is not added to the ignore set. It is
            still skipped if listed in ``extra_symbols_to_ignore``.
        separator: string placed between rendered tokens.

    Returns:
        The rendered (and post-processed) string.
    """
    if torch.is_tensor(tensor) and tensor.dim() == 2:
        # Forward every rendering option so each row is rendered exactly
        # like a 1-D input would be.
        return "\n".join(
            self.string(
                row,
                bpe_symbol,
                escape_unk,
                extra_symbols_to_ignore,
                unk_string=unk_string,
                include_eos=include_eos,
                separator=separator,
            )
            for row in tensor
        )

    # Copy into a fresh set so the caller's collection is never mutated.
    ignored = set(extra_symbols_to_ignore or [])
    if not include_eos:
        ignored.add(self.eos())
    if hasattr(self, "bos_index"):
        ignored.add(self.bos())

    def token_string(i):
        if i == self.unk():
            if unk_string is not None:
                return unk_string
            return self.unk_string(escape_unk)
        return self[i]

    sent = separator.join(
        token_string(i) for i in tensor if utils.item(i) not in ignored
    )

    return data_utils.post_process(sent, bpe_symbol)
def predict(self, wav_path):
    """Transcribe a single audio file.

    Features come from ``self._get_feature(wav_path)`` and are given a
    leading dimension of size 1. They are paired with an all-``False``
    padding mask and passed to ``self._generator.generate``. The first
    hypothesis of the first result is rendered with ``self._target_dict``
    and post-processed in ``'letter'`` mode.

    Returns:
        The value returned by ``post_process`` for the top hypothesis.
    """
    source = self._get_feature(wav_path).unsqueeze(0)
    # Nothing is padded: a single utterance forms the whole batch.
    padding_mask = torch.BoolTensor(source.size(1)).fill_(False).unsqueeze(0)
    sample = {"net_input": {"source": source, "padding_mask": padding_mask}}

    with torch.no_grad():
        hypos = self._generator.generate(
            [self._model], sample, prefix_tokens=None
        )
        top_tokens = hypos[0][0]["tokens"].int().cpu()
        hyp_pieces = self._target_dict.string(top_tokens)
        return post_process(hyp_pieces, 'letter')
def _drop_marked_tokens(tokens):
    """Return the tokens that contain neither ``'['`` nor ``'<'``."""
    return [tok for tok in tokens if '[' not in tok and '<' not in tok]


def process_predictions(args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id, trans):
    """Score the top hypothesis against a raw transcript, optionally dumping both.

    Tokens containing ``'['`` or ``'<'`` are removed from the hypothesis and
    from the reference before scoring. The same filter is applied whether the
    hypothesis comes from ``hypo["words"]`` or from decoding
    ``hypo["tokens"]``. Previously the ``"words"`` branch rebound the
    accumulator to a string and used the opposite filter condition.

    Args:
        args: options object; ``iscn`` and ``quiet`` are read.
        hypos: hypotheses for one utterance; only the first one is used.
        sp: not used by this function.
        tgt_dict: object whose ``tokenizer.decode`` renders token ids.
        target_tokens: not used by this function; the reference comes from
            ``trans``.
        res_files: ``None``, or a mapping with the keys ``"hypo.words"`` and
            ``"ref.words"`` whose values are writable file objects.
        speaker: speaker tag written next to every dumped line.
        id: utterance id written next to every dumped line.
        trans: reference transcript as a whitespace-separated string.

    Returns:
        ``(edit_distance, reference_token_count)``. The function returns
        ``None`` implicitly when ``hypos`` is empty.
    """
    for hypo in hypos[:min(len(hypos), 1)]:
        if "words" in hypo:
            hyp_tokens = _drop_marked_tokens(hypo["words"])
        else:
            hyp_pieces = tgt_dict.tokenizer.decode(hypo["tokens"].int().cpu())
            hyp_tokens = _drop_marked_tokens(hyp_pieces.split())

        tgt_words = ' '.join(_drop_marked_tokens(trans.strip().split()))
        hyp_words = post_process(' '.join(hyp_tokens))

        if args.iscn:
            # Score per character: drop all spaces, then put one space
            # between every pair of characters.
            hyp_words = ' '.join(hyp_words.replace(' ', ''))
            tgt_words = ' '.join(tgt_words.replace(' ', ''))

        if res_files is not None:
            print("{} ({}-{})".format(hyp_words, speaker, id), file=res_files["hypo.words"])
            print("{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"])

        if not args.quiet:
            logger.debug("HYPO:" + hyp_words)
            logger.debug("TARGET:" + tgt_words)
            logger.debug("___________________")

        # only score top hypothesis
        hyp_words = hyp_words.split()
        tgt_words = tgt_words.split()
        return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
def forward(self, model, sample, reduce=True):
    """Compute the summed CTC loss for a batch and, outside training, error counts.

    Args:
        model: called as ``model(**sample["net_input"])``; must also provide
            ``get_normalized_probs`` and a ``training`` flag.
        sample: batch dict. ``net_input``, ``target``, ``target_lengths`` and
            ``id`` are read; ``ntokens`` and ``target_label`` are used when
            present.
        reduce: not used in this method.

    Returns:
        ``(loss, sample_size, logging_output)``. When ``model.training`` is
        false, ``logging_output`` additionally carries unit-level counts
        (``c_errors``, ``c_total``) and word-level counts (``w_errors``,
        ``wv_errors``, ``w_total``).
    """
    net_input = sample["net_input"]
    net_output = model(**net_input)
    # (T, B, C) from the encoder
    lprobs = model.get_normalized_probs(net_output, log_probs=True).contiguous()

    if "src_lengths" in net_input:
        input_lengths = net_input["src_lengths"]
    else:
        # Derive lengths by counting the non-padded positions.
        input_lengths = (~net_output["padding_mask"]).long().sum(-1)

    targets = sample["target"]
    keep = (targets != self.pad_idx) & (targets != self.eos_idx)
    targets_flat = targets.masked_select(keep)
    target_lengths = sample["target_lengths"]

    with torch.backends.cudnn.flags(enabled=False):
        loss = F.ctc_loss(
            lprobs,
            targets_flat,
            input_lengths,
            target_lengths,
            blank=self.blank_idx,
            reduction="sum",
            zero_infinity=self.zero_infinity,
        )

    if "ntokens" in sample:
        ntokens = sample["ntokens"]
    else:
        ntokens = target_lengths.sum().item()
    sample_size = targets.size(0) if self.sentence_avg else ntokens

    logging_output = {
        "loss": utils.item(loss.data),  # * sample['ntokens'],
        "ntokens": ntokens,
        "nsentences": sample["id"].numel(),
        "sample_size": sample_size,
    }

    if not model.training:
        import editdistance

        with torch.no_grad():
            tgt_dict = self.task.target_dictionary
            lprobs_by_sentence = lprobs.transpose(0, 1).float().contiguous().cpu()
            references = (
                sample["target_label"] if "target_label" in sample else sample["target"]
            )

            c_err = c_len = 0
            w_errs = w_len = wv_errs = 0

            for lp, ref, n_frames in zip(lprobs_by_sentence, references, input_lengths):
                lp = lp[:n_frames].unsqueeze(0)

                # Take the first entry of the first list returned by the
                # external decoder, if there is one.
                decoded = None
                if self.w2l_decoder is not None:
                    nbest = self.w2l_decoder.decode(lp)
                    if len(nbest) >= 1 and len(nbest[0]) >= 1:
                        decoded = nbest[0][0]

                is_unit = (ref != tgt_dict.pad()) & (ref != tgt_dict.eos())
                targ = ref[is_unit]
                targ_units = tgt_dict.string(targ)
                targ_units_arr = targ.tolist()

                # Arg-max path: merge consecutive repeats, then drop blanks.
                toks = lp.argmax(dim=-1).unique_consecutive()
                pred_units_arr = toks[toks != self.blank_idx].tolist()

                c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                c_len += len(targ_units_arr)

                targ_words = post_process(targ_units, self.post_process).split()
                pred_units = tgt_dict.string(pred_units_arr)
                pred_words_raw = post_process(pred_units, self.post_process).split()

                # "wv" always scores the arg-max words; "w" uses the
                # decoder's words whenever they are available.
                raw_dist = editdistance.eval(pred_words_raw, targ_words)
                if decoded is not None and "words" in decoded:
                    w_errs += editdistance.eval(decoded["words"], targ_words)
                else:
                    w_errs += raw_dist
                wv_errs += raw_dist

                w_len += len(targ_words)

            logging_output["wv_errors"] = wv_errs
            logging_output["w_errors"] = w_errs
            logging_output["w_total"] = w_len
            logging_output["c_errors"] = c_err
            logging_output["c_total"] = c_len

    return loss, sample_size, logging_output
def process_predictions(
    cfg: UnsupGenerateConfig,
    hypos,
    tgt_dict,
    target_tokens,
    res_files,
):
    """Score up to ``cfg.nbest`` hypotheses and write out the best one.

    Args:
        cfg: configuration; ``nbest``, ``post_process`` and
            ``fairseq.common_eval.quiet`` are read.
        hypos: hypotheses; each is a mapping with ``"tokens"`` (a tensor or a
            sequence of strings) and optionally ``"words"``, ``"score"``,
            ``"am_score"`` and ``"lm_score"``.
        tgt_dict: dictionary providing ``nspecial`` and ``string``.
        target_tokens: ``None``, a reference string, or reference token ids.
        res_files: ``None``, or a mapping with the keys ``"hypo.units"``,
            ``"hypo.words"``, ``"ref.units"`` and ``"ref.words"`` whose values
            are writable file objects.

    Returns:
        A tuple ``(edit_distance, n_hyp_words, n_ref_words, hyp_pieces,
        hyp_words)``. With several hypotheses, the first one with the lowest
        edit distance is written (and flushed) and returned.

    Raises:
        AssertionError: if no hypothesis was processed.
    """
    retval = []
    word_preds = []  # collected but not read again inside this function
    transcriptions = []
    dec_scores = []  # collected but not read again inside this function

    for i, hypo in enumerate(hypos[: min(len(hypos), cfg.nbest)]):
        raw_tokens = hypo["tokens"]
        if torch.is_tensor(raw_tokens):
            ids = raw_tokens.int().cpu()
            # Keep only indices at or above the dictionary's special range.
            ids = ids[ids >= tgt_dict.nspecial]
            hyp_pieces = tgt_dict.string(ids)
        else:
            hyp_pieces = " ".join(raw_tokens)

        if "words" in hypo and len(hypo["words"]) > 0:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, cfg.post_process)

        # Maps destination file -> line to write; filled only when dumping.
        to_write = {}
        if res_files is not None:
            to_write[res_files["hypo.units"]] = hyp_pieces
            to_write[res_files["hypo.words"]] = hyp_words

        tgt_words = ""
        if target_tokens is not None:
            if isinstance(target_tokens, str):
                tgt_pieces = tgt_words = target_tokens
            else:
                tgt_pieces = tgt_dict.string(target_tokens)
                tgt_words = post_process(tgt_pieces, cfg.post_process)

            if res_files is not None:
                to_write[res_files["ref.units"]] = tgt_pieces
                to_write[res_files["ref.words"]] = tgt_words

        if not cfg.fairseq.common_eval.quiet:
            logger.info(f"HYPO {i}:" + hyp_words)
            if tgt_words:
                logger.info("TARGET:" + tgt_words)

            if "am_score" in hypo and "lm_score" in hypo:
                logger.info(
                    f"DECODER AM SCORE: {hypo['am_score']}, DECODER LM SCORE: {hypo['lm_score']}, DECODER SCORE: {hypo['score']}"
                )
            elif "score" in hypo:
                logger.info(f"DECODER SCORE: {hypo['score']}")

            logger.info("___________________")

        hyp_words_arr = hyp_words.split()
        tgt_words_arr = tgt_words.split()

        retval.append(
            (
                editdistance.eval(hyp_words_arr, tgt_words_arr),
                len(hyp_words_arr),
                len(tgt_words_arr),
                hyp_pieces,
                hyp_words,
            )
        )
        word_preds.append(hyp_words_arr)
        transcriptions.append(to_write)
        dec_scores.append(-hypo.get("score", 0))  # negate cuz kaldi returns NLL

    if len(retval) > 1:
        # min() returns the first minimal element, so ties go to the
        # earliest hypothesis.
        best_result, best_lines = min(
            zip(retval, transcriptions), key=lambda pair: pair[0][0]
        )
        for dest, tran in best_lines.items():
            print(tran, file=dest)
            dest.flush()
        return best_result

    assert len(transcriptions) == 1
    for dest, tran in transcriptions[0].items():
        print(tran, file=dest)
    return retval[0]