Example #1
def process_predictions(args, hypos, sp, tgt_dict, target_tokens, res_files,
                        speaker, id):
    for hypo in hypos[:min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())

        if "words" in hypo:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, args.remove_bpe)

        if res_files is not None:
            print("{} ({}-{})".format(hyp_pieces, speaker, id),
                  file=res_files["hypo.units"])
            print("{} ({}-{})".format(hyp_words, speaker, id),
                  file=res_files["hypo.words"])

        tgt_pieces = tgt_dict.string(target_tokens)
        tgt_words = post_process(tgt_pieces, args.remove_bpe)

        if res_files is not None:
            print("{} ({}-{})".format(tgt_pieces, speaker, id),
                  file=res_files["ref.units"])
            print("{} ({}-{})".format(tgt_words, speaker, id),
                  file=res_files["ref.words"])
            # only score top hypothesis
            if not args.quiet:
                logger.debug("HYPO:" + hyp_words)
                logger.debug("TARGET:" + tgt_words)
                logger.debug("___________________")

        hyp_words = hyp_words.split()
        tgt_words = tgt_words.split()
        return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
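The function above scores only the top hypothesis and returns a per-utterance pair of (edit distance, reference length). A minimal sketch of how a caller might fold these pairs into a corpus-level WER; the `results` list below is hypothetical sample data, not output from the example:

results = [(3, 17), (0, 12), (5, 21)]  # hypothetical (errors, reference_length) pairs
total_errors = sum(errs for errs, _ in results)
total_length = sum(ref_len for _, ref_len in results)
wer = 100.0 * total_errors / total_length if total_length else 0.0
print("WER: {:.2f}%".format(wer))  # -> WER: 16.00%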
Example #2
File: infer_ma.py  Project: zjc6666/wav2vec
def process_predictions(args, hypos, sp, tgt_dict, target_tokens, res_files,
                        speaker, id, labels):
    for hypo in hypos[:min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
        if "words" in hypo:
            hypo_words = hypo["words"]
            hypo_chrs = []
            for hypo_word in hypo_words:
                # skip non-speech / filler tokens
                if hypo_word in ('[LAUGHTER]', '[VOCALIZED-NOISE]', '[NOISE]', '!SIL', '<UNK>'):
                    continue
                # split each remaining word into characters for character-level scoring
                for ch in hypo_word:
                    hypo_chrs.append(ch)
            hyp_words = " ".join(hypo_chrs)
        else:
            hyp_words = post_process(hyp_pieces, args.remove_bpe)

        if res_files is not None:
            print("{} ({}-{})".format(hyp_pieces, speaker, id),
                  file=res_files["hypo.units"])
            print("{} ({}-{})".format(hyp_words, speaker, id),
                  file=res_files["hypo.words"])

        tgt_pieces = tgt_dict.string(target_tokens)
        # tgt_words = post_process(tgt_pieces, args.remove_bpe)
        tgt_words = post_process(labels[id], args.remove_bpe)
        import re
        tgt_chrs = []
        tgt_words = re.split(r'\s+', tgt_words)
        for tgt_word in tgt_words:
            # skip non-speech / filler tokens
            if tgt_word in ('[LAUGHTER]', '[VOCALIZED-NOISE]', '[NOISE]', '!SIL', '<UNK>'):
                continue
            # split each remaining word into characters for character-level scoring
            for ch in tgt_word:
                tgt_chrs.append(ch)
        tgt_words = ' '.join(tgt_chrs)

        if res_files is not None:
            print("{} ({}-{})".format(tgt_pieces, speaker, id),
                  file=res_files["ref.units"])
            print("{} ({}-{})".format(tgt_words, speaker, id),
                  file=res_files["ref.words"])
            # only score top hypothesis
            if not args.quiet:
                logger.debug("HYPO:" + hyp_words)
                logger.debug("TARGET:" + tgt_words)
                logger.debug("___________________")

        hyp_words = hyp_words.split()
        tgt_words = tgt_words.split()
        return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
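Because this variant drops non-speech tokens and splits every surviving word into characters before scoring, the edit distance it returns is effectively a character-level count. A small illustration of that effect, assuming the editdistance package is installed:

import editdistance

hyp_chars = list("abce")  # hypothesis characters
ref_chars = list("abcd")  # reference characters
# one substitution out of four reference characters -> 25% character error rate
print(editdistance.eval(hyp_chars, ref_chars), len(ref_chars))  # -> 1 4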
Example #3
File: infer.py  Project: zjc6666/wav2vec
def process_predictions(args, hypos, sp, tgt_dict, target_tokens, res_files,
                        speaker, id, labels):
    for hypo in hypos[:min(len(hypos), args.nbest)]:
        hyp_words = []
        if "words" in hypo:
            for hyp_word in hypo["words"]:
                if hyp_word in list_ignore:
                    continue
                hyp_words.append(hyp_word)
        else:
            hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
            for hypo_chr in hyp_pieces.split():
                if hypo_chr not in list_ignore:
                    hyp_words.append(hypo_chr)

        tgt_words = []
        for tgt_word in labels[id].strip().split():
            if tgt_word not in list_ignore:
                tgt_words.append(tgt_word)

        tgt_words = ' '.join(tgt_words)
        hyp_words = post_process(' '.join(hyp_words), args.labels)

        if args.iscn:
            hyp_words = ' '.join(list(hyp_words.replace(' ', '')))
            tgt_words = ' '.join(list(tgt_words.replace(' ', '')))

        if res_files is not None:
            print("{} ({}-{})".format(hyp_words, speaker, id),
                  file=res_files["hypo.words"])

            print("{} ({}-{})".format(tgt_words, speaker, id),
                  file=res_files["ref.words"])
            # only score top hypothesis
            if not args.quiet:
                logger.debug("HYPO:" + hyp_words)
                logger.debug("TARGET:" + tgt_words)
                logger.debug("___________________")

        hyp_words = hyp_words.split()
        tgt_words = tgt_words.split()
        return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
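When args.iscn is set, this variant strips spaces and re-joins every character with a single space, so both hypothesis and reference are scored per character (useful for Chinese). A tiny illustration of that transformation:

s = "你好 世界"
print(' '.join(list(s.replace(' ', ''))))  # -> 你 好 世 界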
Example #4
    def string(
        self,
        tensor,
        bpe_symbol=None,
        escape_unk=False,
        extra_symbols_to_ignore=None,
        unk_string=None,
    ):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return "\n".join(
                self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore)
                for t in tensor
            )

        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
        extra_symbols_to_ignore.add(self.eos())

        def token_string(i):
            if i == self.unk():
                if unk_string is not None:
                    return unk_string
                else:
                    return self.unk_string(escape_unk)
            else:
                return self[i]

        if hasattr(self, "bos_index"):
            extra_symbols_to_ignore.add(self.bos())

        sent = " ".join(
            token_string(i)
            for i in tensor
            if utils.item(i) not in extra_symbols_to_ignore
        )

        return data_utils.post_process(sent, bpe_symbol)
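A minimal usage sketch for Dictionary.string(), assuming fairseq is installed and a small dictionary is built by hand; the eos index is skipped automatically, and passing bpe_symbol triggers the post-processing step shown on the last line of the method:

import torch
from fairseq.data import Dictionary

d = Dictionary()
for sym in ["HE@@", "LLO"]:
    d.add_symbol(sym)

ids = torch.IntTensor([d.index("HE@@"), d.index("LLO"), d.eos()])
print(d.string(ids))                    # -> HE@@ LLO
print(d.string(ids, bpe_symbol="@@ "))  # -> HELLO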
Example #5
    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        lprobs = model.get_normalized_probs(
            net_output, log_probs=True
        ).contiguous()  # (T, B, C) from the encoder

        if "src_lengths" in sample["net_input"]:
            input_lengths = sample["net_input"]["src_lengths"]
        else:
            non_padding_mask = ~net_output["padding_mask"]
            input_lengths = non_padding_mask.long().sum(-1)

        pad_mask = (sample["target"] != self.pad_idx) & (
            sample["target"] != self.eos_idx
        )
        targets_flat = sample["target"].masked_select(pad_mask)
        target_lengths = sample["target_lengths"]

        with torch.backends.cudnn.flags(enabled=False):
            loss = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=self.zero_infinity,
            )

        ntokens = (
            sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
        )

        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
        logging_output = {
            "loss": utils.item(loss.data),  # * sample['ntokens'],
            "ntokens": ntokens,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
        }

        if not model.training:
            import editdistance

            with torch.no_grad():
                lprobs_t = lprobs.transpose(0, 1).float().cpu()

                c_err = 0
                c_len = 0
                w_errs = 0
                w_len = 0
                wv_errs = 0
                for lp, t, inp_l in zip(
                    lprobs_t,
                    sample["target_label"]
                    if "target_label" in sample
                    else sample["target"],
                    input_lengths,
                ):
                    lp = lp[:inp_l].unsqueeze(0)

                    decoded = None
                    if self.w2l_decoder is not None:
                        decoded = self.w2l_decoder.decode(lp)
                        if len(decoded) < 1:
                            decoded = None
                        else:
                            decoded = decoded[0]
                            if len(decoded) < 1:
                                decoded = None
                            else:
                                decoded = decoded[0]

                    p = (t != self.task.dictionary.pad()) & (
                        t != self.task.dictionary.eos()
                    )
                    targ = t[p]
                    targ_units = self.task.dictionary.string(targ)
                    targ_units_arr = targ.tolist()

                    toks = lp.argmax(dim=-1).unique_consecutive()
                    pred_units_arr = toks[toks != self.blank_idx].tolist()

                    c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                    c_len += len(targ_units_arr)

                    targ_words = post_process(targ_units, self.post_process).split()

                    pred_units = self.task.dictionary.string(pred_units_arr)
                    pred_words_raw = post_process(pred_units, self.post_process).split()

                    if decoded is not None and "words" in decoded:
                        pred_words = decoded["words"]
                        w_errs += editdistance.eval(pred_words, targ_words)
                        wv_errs += editdistance.eval(pred_words_raw, targ_words)
                    else:
                        dist = editdistance.eval(pred_words_raw, targ_words)
                        w_errs += dist
                        wv_errs += dist

                    w_len += len(targ_words)

                logging_output["wv_errors"] = wv_errs
                logging_output["w_errors"] = w_errs
                logging_output["w_total"] = w_len
                logging_output["c_errors"] = c_err
                logging_output["c_total"] = c_len

        return loss, sample_size, logging_output
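The forward pass above only collects raw error and length counters in logging_output; turning them into rates happens at aggregation time. A hedged sketch (not the criterion's actual reduce_metrics) of how those counters would typically be reduced across batches:

def error_rates(logging_outputs):
    # sum the per-batch counters collected in the forward pass above
    c_err = sum(log.get("c_errors", 0) for log in logging_outputs)
    c_len = sum(log.get("c_total", 0) for log in logging_outputs)
    w_err = sum(log.get("w_errors", 0) for log in logging_outputs)
    w_len = sum(log.get("w_total", 0) for log in logging_outputs)
    uer = 100.0 * c_err / c_len if c_len > 0 else float("nan")
    wer = 100.0 * w_err / w_len if w_len > 0 else float("nan")
    return uer, wer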