Code example #1
import time

import torch
from tqdm import tqdm

# post_process, calculate_cer, calculate_wer, calculate_cer_en_zh, and the
# USE_CUDA flag are defined elsewhere in the project.


def evaluate(model, vocab, test_loader, args, lm=None, start_token=-1):
    """
    Evaluation
    args:
        model: Model object
        test_loader: DataLoader object
    """
    model.eval()

    total_word, total_char, total_cer, total_wer = 0, 0, 0, 0
    total_en_cer, total_zh_cer, total_en_char, total_zh_char = 0, 0, 0, 0
    total_hyp_char = 0
    total_time = 0

    with torch.no_grad():
        test_pbar = tqdm(test_loader, leave=False, total=len(test_loader))
        for data in test_pbar:
            src, trg, src_percentages, src_lengths, trg_lengths = data

            if USE_CUDA:
                src = src.cuda()
                trg = trg.cuda()

            start_time = time.time()
            batch_ids_hyps, batch_strs_hyps, batch_strs_gold = model.evaluate(
                src, src_lengths, trg, args, lm_rescoring=args.lm_rescoring, lm=lm, lm_weight=args.lm_weight, beam_search=args.beam_search, beam_width=args.beam_width, beam_nbest=args.beam_nbest, c_weight=args.c_weight, start_token=start_token, verbose=args.verbose)

            for x in range(len(batch_strs_gold)):
                hyp = post_process(batch_strs_hyps[x], vocab.special_token_list)
                gold = post_process(batch_strs_gold[x], vocab.special_token_list)

                wer = calculate_wer(hyp, gold)
                cer = calculate_cer(hyp.strip(), gold.strip())

                if args.verbose:
                    print("HYP:", hyp)
                    print("GOLD:", gold)
                    print("CER:", cer)

                en_cer, zh_cer, num_en_char, num_zh_char = calculate_cer_en_zh(hyp, gold)
                total_en_cer += en_cer
                total_zh_cer += zh_cer
                total_en_char += num_en_char
                total_zh_char += num_zh_char
                total_hyp_char += len(hyp)

                total_wer += wer
                total_cer += cer
                total_word += len(gold.split(" "))
                total_char += len(gold)

            end_time = time.time()
            total_time += end_time - start_time

            test_pbar.set_description("TEST CER:{:.2f}% WER:{:.2f}% CER_EN:{:.2f}% CER_ZH:{:.2f}% TOTAL_TIME:{:.7f} TOTAL HYP CHAR:{:d}".format(
                total_cer*100/max(1, total_char), total_wer*100/max(1, total_word), total_en_cer*100/max(1, total_en_char), total_zh_cer*100/max(1, total_zh_char), total_time, total_hyp_char))
            print("TEST CER:{:.2f}% WER:{:.2f}% CER_EN:{:.2f}% CER_ZH:{:.2f}% TOTAL_TIME:{:.7f} TOTAL HYP CHAR:{:d}".format(
                total_cer*100/max(1, total_char), total_wer*100/max(1, total_word), total_en_cer*100/max(1, total_en_char), total_zh_cer*100/max(1, total_zh_char), total_time, total_hyp_char), flush=True)
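
A minimal driver sketch for evaluate(), assuming model, vocab, and test_loader are built elsewhere in the project; the SimpleNamespace fields simply mirror the args attributes the function reads.

# Hypothetical driver; model, vocab, and test_loader are assumed to come
# from the surrounding project. The fields below mirror exactly the
# attributes evaluate() reads from args.
from types import SimpleNamespace

args = SimpleNamespace(
    lm_rescoring=False, lm_weight=0.0,
    beam_search=True, beam_width=5, beam_nbest=1,
    c_weight=0.0, verbose=False)

evaluate(model, vocab, test_loader, args, lm=None, start_token=-1)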
Code example #2
    def train_one_batch(self, model, vocab, src, trg, src_percentages,
                        src_lengths, trg_lengths, smoothing, loss_type):
        """Run one forward pass and return (loss, total_cer, total_char)."""
        pred, gold, hyp = model(src, src_lengths, trg, verbose=False)
        strs_golds, strs_hyps = [], []

        for ut_gold in gold:
            strs_golds.append("".join(
                [vocab.id2label[int(x)] for x in ut_gold]))

        for ut_hyp in hyp:
            strs_hyps.append("".join([vocab.id2label[int(x)] for x in ut_hyp]))

        # Recover per-utterance input lengths from the padded-sequence
        # percentages (this also handles the shorter last batch); the
        # out-of-place mul() avoids mutating the caller's src_percentages.
        seq_length = pred.size(1)
        sizes = src_percentages.mul(seq_length).int()

        loss, num_correct = calculate_metrics(pred,
                                              gold,
                                              vocab.PAD_ID,
                                              input_lengths=sizes,
                                              target_lengths=trg_lengths,
                                              smoothing=smoothing,
                                              loss_type=loss_type)

        if loss is None:
            raise ValueError("loss is None")

        # Zero out non-finite (inf/NaN) losses so one bad batch cannot
        # poison the update.
        if not torch.isfinite(loss).all():
            logging.info("Found non-finite loss, masking")
            loss = torch.where(torch.isfinite(loss), loss,
                               torch.zeros_like(loss))

        total_cer, total_wer, total_char, total_word = 0, 0, 0, 0
        for j in range(len(strs_hyps)):
            strs_hyps[j] = post_process(strs_hyps[j], vocab.special_token_list)
            strs_golds[j] = post_process(strs_golds[j],
                                         vocab.special_token_list)
            cer = calculate_cer(strs_hyps[j].replace(' ', ''),
                                strs_golds[j].replace(' ', ''))
            wer = calculate_wer(strs_hyps[j], strs_golds[j])
            total_cer += cer
            total_wer += wer
            total_char += len(strs_golds[j].replace(' ', ''))
            total_word += len(strs_golds[j].split(" "))

        return loss, total_cer, total_char
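
A hedged sketch of the outer training loop this method implies; trainer, optimizer, train_loader, and the loss_type value "ce" are assumptions, while the batch tuple layout matches evaluate() in code example #1.

# Hypothetical outer loop; trainer, optimizer, train_loader, and
# loss_type="ce" are assumptions; the batch layout mirrors evaluate().
for src, trg, src_percentages, src_lengths, trg_lengths in train_loader:
    optimizer.zero_grad()
    loss, batch_cer, batch_char = trainer.train_one_batch(
        model, vocab, src, trg, src_percentages,
        src_lengths, trg_lengths, smoothing=0.0, loss_type="ce")
    loss.backward()
    optimizer.step()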
Code example #3
    def forward_one_batch(self,
                          model,
                          vocab,
                          src,
                          trg,
                          src_percentages,
                          src_lengths,
                          trg_lengths,
                          smoothing,
                          loss_type,
                          verbose=False,
                          discriminator=None,
                          accent_id=None,
                          multi_task=False):
        """Forward one batch; optionally run an accent discriminator for
        adversarial or multi-task training."""
        if discriminator is None:
            pred, gold, hyp = model(src, src_lengths, trg, verbose=False)
        else:
            enc_output = model.encode(src, src_lengths)
            accent_pred = discriminator(torch.sum(enc_output, dim=1))
            pred, gold, hyp = model.decode(enc_output, src_lengths, trg)
            if multi_task:
                # multi-task: classification loss on the predicted accent
                disc_loss = calculate_multi_task(accent_pred, accent_id)
            else:
                # adversarial: separate losses for the discriminator
                # and the encoder
                disc_loss, enc_loss = calculate_adversarial(
                    accent_pred, accent_id)

        strs_golds, strs_hyps = [], []

        for ut_gold in gold:
            strs_golds.append("".join(
                [vocab.id2label[int(x)] for x in ut_gold]))

        for ut_hyp in hyp:
            strs_hyps.append("".join([vocab.id2label[int(x)] for x in ut_hyp]))

        # Recover per-utterance input lengths from the padded-sequence
        # percentages (this also handles the shorter last batch); the
        # out-of-place mul() avoids mutating the caller's src_percentages.
        seq_length = pred.size(1)
        sizes = src_percentages.mul(seq_length).int()

        loss, _ = calculate_metrics(pred,
                                    gold,
                                    vocab.PAD_ID,
                                    input_lengths=sizes,
                                    target_lengths=trg_lengths,
                                    smoothing=smoothing,
                                    loss_type=loss_type)

        if loss is None:
            raise ValueError("loss is None")

        # Zero out non-finite (inf/NaN) losses so one bad batch cannot
        # poison the update.
        if not torch.isfinite(loss).all():
            logging.info("Found non-finite loss, masking")
            loss = torch.where(torch.isfinite(loss), loss,
                               torch.zeros_like(loss))

        total_cer, total_wer, total_char, total_word = 0, 0, 0, 0
        for j in range(len(strs_hyps)):
            strs_hyps[j] = post_process(strs_hyps[j], vocab.special_token_list)
            strs_golds[j] = post_process(strs_golds[j],
                                         vocab.special_token_list)
            cer = calculate_cer(strs_hyps[j].replace(' ', ''),
                                strs_golds[j].replace(' ', ''))
            wer = calculate_wer(strs_hyps[j], strs_golds[j])
            total_cer += cer
            total_wer += wer
            total_char += len(strs_golds[j].replace(' ', ''))
            total_word += len(strs_golds[j].split(" "))

        if verbose:
            print('Total CER', total_cer)
            print('Total char', total_char)

            print("PRED:", strs_hyps)
            print("GOLD:", strs_golds, flush=True)

        if discriminator is None:
            return loss, total_cer, total_char
        else:
            if multi_task:
                return loss, total_cer, total_char, disc_loss
            else:
                return loss, total_cer, total_char, disc_loss, enc_loss
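
A sketch showing how the discriminator branches change the return arity; discriminator, accent_id, and the surrounding trainer object are assumptions based only on the signature above.

# Hypothetical call on the adversarial path (multi_task=False), which
# returns five values instead of three; discriminator and accent_id are
# assumptions implied by the signature.
loss, total_cer, total_char, disc_loss, enc_loss = trainer.forward_one_batch(
    model, vocab, src, trg, src_percentages, src_lengths, trg_lengths,
    smoothing=0.0, loss_type="ce",
    discriminator=discriminator, accent_id=accent_id, multi_task=False)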