Code Example #1
File: infer.py  Project: mtanana/LyssnFairSeq
def process_predictions(args, hypos, sp, tgt_dict, target_tokens, res_files,
                        speaker, id):
    for hypo in hypos[:min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())

        if "words" in hypo:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, args.remove_bpe)

        if res_files is not None:
            print("{} ({}-{})".format(hyp_pieces, speaker, id),
                  file=res_files["hypo.units"])
            print("{} ({}-{})".format(hyp_words, speaker, id),
                  file=res_files["hypo.words"])

        tgt_pieces = tgt_dict.string(target_tokens)
        tgt_words = post_process(tgt_pieces, args.remove_bpe)

        if res_files is not None:
            print("{} ({}-{})".format(tgt_pieces, speaker, id),
                  file=res_files["ref.units"])
            print("{} ({}-{})".format(tgt_words, speaker, id),
                  file=res_files["ref.words"])
            # only score top hypothesis
            if not args.quiet:
                logger.debug("HYPO:" + hyp_words)
                logger.debug("TARGET:" + tgt_words)
                logger.debug("___________________")

        hyp_words = hyp_words.split()
        tgt_words = tgt_words.split()
        return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
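Each call returns a per-utterance pair of (word edit distance, reference word count) rather than a finished error rate. A minimal sketch of how a caller might fold those pairs into a corpus-level WER (the corpus_wer helper below is hypothetical, not part of the project above):

def corpus_wer(pairs):
    # pairs: iterable of (word_errors, ref_word_count) tuples, one per
    # utterance, as returned by process_predictions
    total_errors = sum(errors for errors, _ in pairs)
    total_words = sum(n for _, n in pairs)
    return 100.0 * total_errors / total_words if total_words else 0.0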
Code Example #2
File: test_wav2vec2.py  Project: ishine/fairseq
    def postprocess_tokens(self, task, target, hypo_tokens):
        tgt_tokens = utils.strip_pad(target,
                                     task.target_dictionary.pad()).int().cpu()
        tgt_str = task.target_dictionary.string(tgt_tokens)
        tgt_str = post_process(tgt_str, "letter")

        hypo_pieces = task.target_dictionary.string(hypo_tokens)
        hypo_str = post_process(hypo_pieces, "letter")
        return tgt_str, hypo_str
Code Example #3
File: infer.py  Project: mthrok/fairseq
    def process_sentence(
        self,
        sample: Dict[str, Any],
        hypo: Dict[str, Any],
        sid: int,
        batch_id: int,
    ) -> Tuple[int, int]:
        speaker = None  # Speaker can't be parsed from dataset.

        if "target_label" in sample:
            toks = sample["target_label"]
        else:
            toks = sample["target"]
        toks = toks[batch_id, :]

        # Processes hypothesis.
        hyp_pieces = self.tgt_dict.string(hypo["tokens"].int().cpu())
        if "words" in hypo:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces,
                                     self.cfg.common_eval.post_process)

        # Processes target.
        target_tokens = utils.strip_pad(toks, self.tgt_dict.pad())
        tgt_pieces = self.tgt_dict.string(target_tokens.int().cpu())
        tgt_words = post_process(tgt_pieces, self.cfg.common_eval.post_process)

        if self.cfg.decoding.results_path is not None:
            print(f"{hyp_pieces} ({speaker}-{sid})", file=self.hypo_units_file)
            print(f"{hyp_words} ({speaker}-{sid})", file=self.hypo_words_file)
            print(f"{tgt_pieces} ({speaker}-{sid})", file=self.ref_units_file)
            print(f"{tgt_words} ({speaker}-{sid})", file=self.ref_words_file)

        if not self.cfg.common_eval.quiet:
            logger.info(f"HYPO: {hyp_words}")
            logger.info(f"REF: {tgt_words}")
            logger.info("---------------------")

        hyp_words, tgt_words = hyp_words.split(), tgt_words.split()

        return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
Code Example #4
File: audio_pretraining.py  Project: StatNLP/ada4asr
    def _inference_with_wer(self, generator, sample, model):
        def decode(toks, escape_unk=True):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.args.eval_wer_remove_bpe,
                escape_unk=escape_unk,
                extra_symbols_to_ignore={generator.eos},
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        num_word_errors, num_char_errors = 0, 0
        num_chars, num_words = 0, 0
        gen_out = self.inference_step(generator, [model], sample, None)
        for i in range(len(gen_out)):
            hyp = decode(gen_out[i][0]["tokens"])
            ref = decode(
                utils.strip_pad(sample["target"][i],
                                self.target_dictionary.pad()),
                escape_unk=True,
            )
            hyp = post_process(hyp, self.args.eval_wer_remove_bpe).strip("_")
            ref = post_process(ref, self.args.eval_wer_remove_bpe).strip("_")
            num_char_errors += editdistance.eval(hyp, ref)
            num_chars += len(ref)
            hyp_words = hyp.split("_")
            ref_words = ref.split("_")
            num_word_errors += editdistance.eval(hyp_words, ref_words)
            num_words += len(ref_words)

        return {
            "num_char_errors": num_char_errors,
            "num_chars": num_chars,
            "num_word_errors": num_word_errors,
            "num_words": num_words,
        }
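The method returns raw counts instead of ratios so they can be summed across validation batches before dividing. A sketch of that aggregation, assuming a list of such dicts (aggregate_error_rates is a hypothetical helper):

def aggregate_error_rates(batch_outputs):
    # batch_outputs: list of dicts as returned by _inference_with_wer
    word_errors = sum(o["num_word_errors"] for o in batch_outputs)
    words = sum(o["num_words"] for o in batch_outputs)
    char_errors = sum(o["num_char_errors"] for o in batch_outputs)
    chars = sum(o["num_chars"] for o in batch_outputs)
    wer = 100.0 * word_errors / max(words, 1)
    cer = 100.0 * char_errors / max(chars, 1)
    return wer, cer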
Code Example #5
def process_predictions(args, hypos, sp, tgt_dict, target_tokens, res_files,
                        speaker, id):
    for hypo in hypos[:min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())

        if "words" in hypo:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, args.post_process)

        if res_files is not None:
            print(
                "{} ({}-{})".format(hyp_pieces, speaker, id),
                file=res_files["hypo.units"],
            )
            print(
                "{} ({}-{})".format(hyp_words, speaker, id),
                file=res_files["hypo.words"],
            )

        tgt_pieces = tgt_dict.string(target_tokens)
        tgt_words = post_process(tgt_pieces, args.post_process)

        if res_files is not None:
            print(
                "{} ({}-{})".format(tgt_pieces, speaker, id),
                file=res_files["ref.units"],
            )
            print("{} ({}-{})".format(tgt_words, speaker, id),
                  file=res_files["ref.words"])
            # only score top hypothesis

        hyp_words = hyp_words.split()
        tgt_words = tgt_words.split()
        #return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
        return 0, len(tgt_words)
Code Example #6
    def string(
        self,
        tensor,
        bpe_symbol=None,
        escape_unk=False,
        extra_symbols_to_ignore=None,
        unk_string=None,
        include_eos=False,
        separator=" ",
    ):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return "\n".join(
                self.string(
                    t,
                    bpe_symbol,
                    escape_unk,
                    extra_symbols_to_ignore,
                    include_eos=include_eos,
                )
                for t in tensor
            )

        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
        extra_symbols_to_ignore.add(self.eos())

        def token_string(i):
            if i == self.unk():
                if unk_string is not None:
                    return unk_string
                else:
                    return self.unk_string(escape_unk)
            else:
                return self[i]

        if hasattr(self, "bos_index"):
            extra_symbols_to_ignore.add(self.bos())

        sent = separator.join(
            token_string(i)
            for i in tensor
            if utils.item(i) not in extra_symbols_to_ignore
        )

        return data_utils.post_process(sent, bpe_symbol)
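A minimal usage sketch, assuming fairseq and torch are installed; the toy vocabulary is made up for illustration:

import torch
from fairseq.data import Dictionary

d = Dictionary()
for sym in ("h", "e", "l", "o"):
    d.add_symbol(sym)

# string() drops the trailing eos that fairseq datasets usually append
indices = torch.tensor(
    [d.index("h"), d.index("e"), d.index("l"), d.index("l"),
     d.index("o"), d.eos()]
)
print(d.string(indices))  # "h e l l o"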
Code Example #7
    def predict(self, wav_path):
        sample = dict()
        net_input = dict()

        feature = self._get_feature(wav_path)
        net_input["source"] = feature.unsqueeze(0)

        padding_mask = (
            torch.BoolTensor(net_input["source"].size(1)).fill_(False).unsqueeze(0)
        )

        net_input["padding_mask"] = padding_mask
        sample["net_input"] = net_input

        with torch.no_grad():
            hypo = self._generator.generate([self._model], sample, prefix_tokens=None)

        hyp_pieces = self._target_dict.string(hypo[0][0]["tokens"].int().cpu())
        return post_process(hyp_pieces, "letter")
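A hypothetical caller; the class name and constructor arguments below are assumptions, since the snippet omits the surrounding class that sets up self._model, self._generator, and self._target_dict:

# Hypothetical usage (class name and constructor are assumed):
predictor = Wav2Vec2Predictor(checkpoint="wav2vec2.pt", dict_path="dict.ltr.txt")
transcript = predictor.predict("utterance.wav")
print(transcript)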
Code Example #8
File: infer_cif_bert.py  Project: eastonYi/fairseq
def process_predictions(args, hypos, sp, tgt_dict, target_tokens, res_files,
                        speaker, id, trans):
    for hypo in hypos[:min(len(hypos), 1)]:
        hyp_words = []

        if "words" in hypo:
            hyp_words = " ".join(hypo["words"])
            for hyp_word in hypo["words"]:
                if '[' not in hyp_word and '<' not in hyp_word:
                    continue
                hyp_words.append(hyp_word)
        else:
            hyp_pieces = tgt_dict.tokenizer.decode(hypo["tokens"].int().cpu())
            for hypo_chr in hyp_pieces.split():
                if '[' not in hypo_chr and '<' not in hypo_chr:
                    hyp_words.append(hypo_chr)
        tgt_words = []
        for tgt_word in trans.strip().split():
            if '[' not in tgt_word and '<' not in tgt_word:
                tgt_words.append(tgt_word)

        tgt_words = ' '.join(tgt_words)
        hyp_words = post_process(' '.join(hyp_words))

        if args.iscn:
            hyp_words = ' '.join(list(hyp_words.replace(' ', '')))
            tgt_words = ' '.join(list(tgt_words.replace(' ', '')))

        if res_files is not None:
            print("{} ({}-{})".format(hyp_words, speaker, id),
                  file=res_files["hypo.words"])
            print("{} ({}-{})".format(tgt_words, speaker, id),
                  file=res_files["ref.words"])
        # only score top hypothesis
        if not args.quiet:
            logger.debug("HYPO:" + hyp_words)
            logger.debug("TARGET:" + tgt_words)
            logger.debug("___________________")

        hyp_words = hyp_words.split()
        tgt_words = tgt_words.split()
        return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
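When args.iscn is set, the snippet rescores at the character level by removing inter-word spaces and re-joining individual characters, the usual convention for Chinese CER. The transformation in isolation (a sketch):

def to_char_level(text):
    # "ab cd" -> "a b c d": drop word spaces, then treat each character
    # as its own token for edit-distance scoring
    return ' '.join(list(text.replace(' ', '')))

print(to_char_level("ab cd"))  # "a b c d"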
Code Example #9
    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        lprobs = model.get_normalized_probs(
            net_output,
            log_probs=True).contiguous()  # (T, B, C) from the encoder

        if "src_lengths" in sample["net_input"]:
            input_lengths = sample["net_input"]["src_lengths"]
        else:
            non_padding_mask = ~net_output["padding_mask"]
            input_lengths = non_padding_mask.long().sum(-1)

        pad_mask = (sample["target"] != self.pad_idx) & (sample["target"] !=
                                                         self.eos_idx)
        targets_flat = sample["target"].masked_select(pad_mask)
        target_lengths = sample["target_lengths"]

        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=self.zero_infinity,
            )

        ntokens = (sample["ntokens"]
                   if "ntokens" in sample else target_lengths.sum().item())

        sample_size = sample["target"].size(
            0) if self.sentence_avg else ntokens
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": ntokens,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        if not model.training:
            import editdistance

            with torch.no_grad():
                lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()

                c_err = 0
                c_len = 0
                w_errs = 0
                w_len = 0
                wv_errs = 0
                for lp, t, inp_l in zip(
                        lprobs_t,
                        sample["target_label"]
                        if "target_label" in sample else sample["target"],
                        input_lengths,
                ):
                    lp = lp[:inp_l].unsqueeze(0)

                    decoded = None
                    if self.w2l_decoder is not None:
                        decoded = self.w2l_decoder.decode(lp)
                        if len(decoded) < 1:
                            decoded = None
                        else:
                            decoded = decoded[0]
                            if len(decoded) < 1:
                                decoded = None
                            else:
                                decoded = decoded[0]

                    p = (t != self.task.target_dictionary.pad()) & (
                        t != self.task.target_dictionary.eos())
                    targ = t[p]
                    targ_units = self.task.target_dictionary.string(targ)
                    targ_units_arr = targ.tolist()

                    toks = lp.argmax(dim=-1).unique_consecutive()
                    pred_units_arr = toks[toks != self.blank_idx].tolist()

                    c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                    c_len += len(targ_units_arr)

                    targ_words = post_process(targ_units,
                                              self.post_process).split()

                    pred_units = self.task.target_dictionary.string(
                        pred_units_arr)
                    pred_words_raw = post_process(pred_units,
                                                  self.post_process).split()

                    if decoded is not None and "words" in decoded:
                        pred_words = decoded["words"]
                        w_errs += editdistance.eval(pred_words, targ_words)
                        wv_errs += editdistance.eval(pred_words_raw,
                                                     targ_words)
                    else:
                        dist = editdistance.eval(pred_words_raw, targ_words)
                        w_errs += dist
                        wv_errs += dist

                    w_len += len(targ_words)

                logging_output["wv_errors"] = wv_errs
                logging_output["w_errors"] = w_errs
                logging_output["w_total"] = w_len
                logging_output["c_errors"] = c_err
                logging_output["c_total"] = c_len

        return loss, sample_size, logging_output
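When no external decoder is configured, the validation branch above decodes greedily: argmax over classes, collapse repeated frames, drop blanks. The same step in isolation (a sketch; blank_idx is assumed to be the index of the CTC blank symbol):

import torch

def greedy_ctc_decode(lprobs, blank_idx):
    # lprobs: (T, C) per-frame log-probabilities for one utterance
    toks = lprobs.argmax(dim=-1).unique_consecutive()
    return toks[toks != blank_idx].tolist()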
Code Example #10
File: w2vu_generate.py  Project: lingss0918/fairseq
def process_predictions(
    cfg: UnsupGenerateConfig,
    hypos,
    tgt_dict,
    target_tokens,
    res_files,
):
    retval = []
    word_preds = []
    transcriptions = []
    dec_scores = []

    for i, hypo in enumerate(hypos[: min(len(hypos), cfg.nbest)]):
        if torch.is_tensor(hypo["tokens"]):
            tokens = hypo["tokens"].int().cpu()
            tokens = tokens[tokens >= tgt_dict.nspecial]
            hyp_pieces = tgt_dict.string(tokens)
        else:
            hyp_pieces = " ".join(hypo["tokens"])

        if "words" in hypo and len(hypo["words"]) > 0:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, cfg.post_process)

        to_write = {}
        if res_files is not None:
            to_write[res_files["hypo.units"]] = hyp_pieces
            to_write[res_files["hypo.words"]] = hyp_words

        tgt_words = ""
        if target_tokens is not None:
            if isinstance(target_tokens, str):
                tgt_pieces = tgt_words = target_tokens
            else:
                tgt_pieces = tgt_dict.string(target_tokens)
                tgt_words = post_process(tgt_pieces, cfg.post_process)

            if res_files is not None:
                to_write[res_files["ref.units"]] = tgt_pieces
                to_write[res_files["ref.words"]] = tgt_words

        if not cfg.fairseq.common_eval.quiet:
            logger.info(f"HYPO {i}:" + hyp_words)
            if tgt_words:
                logger.info("TARGET:" + tgt_words)

            if "am_score" in hypo and "lm_score" in hypo:
                logger.info(
                    f"DECODER AM SCORE: {hypo['am_score']}, DECODER LM SCORE: {hypo['lm_score']}, DECODER SCORE: {hypo['score']}"
                )
            elif "score" in hypo:
                logger.info(f"DECODER SCORE: {hypo['score']}")

            logger.info("___________________")

        hyp_words_arr = hyp_words.split()
        tgt_words_arr = tgt_words.split()

        retval.append(
            (
                editdistance.eval(hyp_words_arr, tgt_words_arr),
                len(hyp_words_arr),
                len(tgt_words_arr),
                hyp_pieces,
                hyp_words,
            )
        )
        word_preds.append(hyp_words_arr)
        transcriptions.append(to_write)
        dec_scores.append(-hypo.get("score", 0))  # negate because kaldi returns NLL

    if len(retval) > 1:
        best = None
        for r, t in zip(retval, transcriptions):
            if best is None or r[0] < best[0][0]:
                best = r, t
        for dest, tran in best[1].items():
            print(tran, file=dest)
            dest.flush()
        return best[0]

    assert len(transcriptions) == 1
    for dest, tran in transcriptions[0].items():
        print(tran, file=dest)

    return retval[0]