Example #1
import time

import torch
from tqdm import tqdm

# Assumed to come from the project's own modules: USE_CUDA, post_process,
# calculate_wer, calculate_cer, calculate_cer_en_zh.

def evaluate(model, vocab, test_loader, args, lm=None, start_token=-1):
    """
    Evaluation
    args:
        model: Model object
        vocab: Vocab object (id2label and the special token list)
        test_loader: DataLoader object of the test set
        args: parsed arguments with the decoding/LM options
        lm: optional language model for rescoring
        start_token: id of the decoder start token
    """
    model.eval()

    total_word, total_char, total_cer, total_wer = 0, 0, 0, 0
    total_en_cer, total_zh_cer, total_en_char, total_zh_char = 0, 0, 0, 0
    total_hyp_char = 0
    total_time = 0

    with torch.no_grad():
        test_pbar = tqdm(iter(test_loader), leave=False, total=len(test_loader))
        for i, data in enumerate(test_pbar):
            src, trg, src_percentages, src_lengths, trg_lengths = data

            if USE_CUDA:
                src = src.cuda()
                trg = trg.cuda()

            start_time = time.time()
            batch_ids_hyps, batch_strs_hyps, batch_strs_gold = model.evaluate(
                src, src_lengths, trg, args,
                lm_rescoring=args.lm_rescoring,
                lm=lm,
                lm_weight=args.lm_weight,
                beam_search=args.beam_search,
                beam_width=args.beam_width,
                beam_nbest=args.beam_nbest,
                c_weight=args.c_weight,
                start_token=start_token,
                verbose=args.verbose)

            for x in range(len(batch_strs_gold)):
                hyp = post_process(batch_strs_hyps[x], vocab.special_token_list)
                gold = post_process(batch_strs_gold[x], vocab.special_token_list)

                wer = calculate_wer(hyp, gold)
                cer = calculate_cer(hyp.strip(), gold.strip())

                if args.verbose:
                    print("HYP:", hyp)
                    print("GOLD:", gold)
                    print("CER:", cer)

                en_cer, zh_cer, num_en_char, num_zh_char = calculate_cer_en_zh(hyp, gold)
                total_en_cer += en_cer
                total_zh_cer += zh_cer
                total_en_char += num_en_char
                total_zh_char += num_zh_char
                total_hyp_char += len(hyp)

                total_wer += wer
                total_cer += cer
                total_word += len(gold.split(" "))
                total_char += len(gold)

            end_time = time.time()
            diff_time = end_time - start_time
            total_time += diff_time
            diff_time_per_word = total_time / total_word

            test_pbar.set_description("TEST CER:{:.2f}% WER:{:.2f}% CER_EN:{:.2f}% CER_ZH:{:.2f}% TOTAL_TIME:{:.7f} TOTAL HYP CHAR:{:.2f}".format(
                total_cer*100/total_char, total_wer*100/total_word, total_en_cer*100/max(1, total_en_char), total_zh_cer*100/max(1, total_zh_char), total_time, total_hyp_char))
            print("TEST CER:{:.2f}% WER:{:.2f}% CER_EN:{:.2f}% CER_ZH:{:.2f}% TOTAL_TIME:{:.7f} TOTAL HYP CHAR:{:.2f}".format(
                total_cer*100/total_char, total_wer*100/total_word, total_en_cer*100/max(1, total_en_char), total_zh_cer*100/max(1, total_zh_char), total_time, total_hyp_char), flush=True)
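
Note: post_process above is a project helper that is not shown in this example. A minimal sketch, assuming it simply strips every special token from a decoded string (consistent with the manual replace chain in Example #2):

# Hypothetical sketch of post_process; the real helper may differ.
def post_process(text, special_token_list):
    """Remove special tokens (e.g. SOS/EOS/PAD markers) from a decoded string."""
    for token in special_token_list:
        text = text.replace(token, "")
    return text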
Example #2
import torch
from tqdm import tqdm

# Assumed project modules: constant (args, CUDA flag, and special characters),
# calculate_wer, calculate_cer.

def evaluate(model, test_loader, lm=None):
    """
    Evaluation
    args:
        model: Model object
        test_loader: DataLoader object of the test set
        lm: optional language model for rescoring
    """
    model.eval()

    total_word, total_char, total_cer, total_wer = 0, 0, 0, 0

    with torch.no_grad():
        test_pbar = tqdm(iter(test_loader), leave=True, total=len(test_loader))
        for i, data in enumerate(test_pbar):
            src, tgt, src_percentages, src_lengths, tgt_lengths = data

            if constant.USE_CUDA:
                src = src.cuda()
                tgt = tgt.cuda()

            batch_ids_hyps, batch_strs_hyps, batch_strs_gold = model.evaluate(
                src,
                src_lengths,
                tgt,
                beam_search=constant.args.beam_search,
                beam_width=constant.args.beam_width,
                beam_nbest=constant.args.beam_nbest,
                lm=lm,
                lm_rescoring=constant.args.lm_rescoring,
                lm_weight=constant.args.lm_weight,
                c_weight=constant.args.c_weight,
                verbose=constant.args.verbose)

            for x in range(len(batch_strs_gold)):
                hyp = batch_strs_hyps[x]
                gold = batch_strs_gold[x]
                for special_char in (constant.EOS_CHAR, constant.SOS_CHAR,
                                     constant.PAD_CHAR):
                    hyp = hyp.replace(special_char, "")
                    gold = gold.replace(special_char, "")

                wer = calculate_wer(hyp, gold)
                cer = calculate_cer(hyp.strip(), gold.strip())
                total_wer += wer
                total_cer += cer
                total_word += len(gold.split(" "))
                total_char += len(gold)

            test_pbar.set_description("TEST CER:{:.2f}% WER:{:.2f}%".format(
                total_cer * 100 / total_char, total_wer * 100 / total_word))
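
Both evaluate variants lean on calculate_cer and calculate_wer from the project's metrics module. A minimal sketch under the standard definitions (raw Levenshtein edit distance over characters and over whitespace-split words; note the loops above divide by the reference length themselves):

def _edit_distance(a, b):
    """Levenshtein distance between two sequences (dynamic programming)."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        curr = [i]
        for j, y in enumerate(b, 1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (x != y)))  # substitution
        prev = curr
    return prev[-1]

def calculate_cer(hyp, gold):
    """Character-level edit distance (unnormalized)."""
    return _edit_distance(list(hyp), list(gold))

def calculate_wer(hyp, gold):
    """Word-level edit distance (unnormalized)."""
    return _edit_distance(hyp.split(), gold.split())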
Example #3
    # Method of a trainer class; assumes module-level imports of torch and
    # logging, plus the project helpers calculate_metrics, calculate_cer,
    # calculate_wer, and post_process.
    def train_one_batch(self, model, vocab, src, trg, src_percentages,
                        src_lengths, trg_lengths, smoothing, loss_type):
        pred, gold, hyp = model(src, src_lengths, trg, verbose=False)
        strs_golds, strs_hyps = [], []

        for j in range(len(gold)):
            ut_gold = gold[j]
            strs_golds.append("".join(
                [vocab.id2label[int(x)] for x in ut_gold]))

        for j in range(len(hyp)):
            ut_hyp = hyp[j]
            strs_hyps.append("".join([vocab.id2label[int(x)] for x in ut_hyp]))

        # convert src_percentages (the fraction of the padded length that is
        # real input) into absolute output lengths; note mul_ mutates in place
        seq_length = pred.size(1)
        sizes = src_percentages.mul_(int(seq_length)).int()

        loss, num_correct = calculate_metrics(pred,
                                              gold,
                                              vocab.PAD_ID,
                                              input_lengths=sizes,
                                              target_lengths=trg_lengths,
                                              smoothing=smoothing,
                                              loss_type=loss_type)

        if loss is None:
            raise ValueError("loss is None")

        if not torch.isfinite(loss):
            logging.info("Found non-finite loss, masking")
            print("Found non-finite loss, masking")
            loss = torch.where(torch.isfinite(loss), loss,
                               torch.zeros_like(loss))  # zero out NaN/Inf

        total_cer, total_wer, total_char, total_word = 0, 0, 0, 0
        for j in range(len(strs_hyps)):
            strs_hyps[j] = post_process(strs_hyps[j], vocab.special_token_list)
            strs_golds[j] = post_process(strs_golds[j],
                                         vocab.special_token_list)
            cer = calculate_cer(strs_hyps[j].replace(' ', ''),
                                strs_golds[j].replace(' ', ''))
            wer = calculate_wer(strs_hyps[j], strs_golds[j])
            total_cer += cer
            total_wer += wer
            total_char += len(strs_golds[j].replace(' ', ''))
            total_word += len(strs_golds[j].split(" "))

        return loss, total_cer, total_char
Example #4
    def forward_one_batch(self,
                          model,
                          vocab,
                          src,
                          trg,
                          src_percentages,
                          src_lengths,
                          trg_lengths,
                          smoothing,
                          loss_type,
                          verbose=False,
                          discriminator=None,
                          accent_id=None,
                          multi_task=False):
        if discriminator is None:
            pred, gold, hyp = model(src, src_lengths, trg, verbose=False)
        else:
            enc_output = model.encode(src, src_lengths)
            accent_pred = discriminator(torch.sum(enc_output, dim=1))
            pred, gold, hyp = model.decode(enc_output, src_lengths, trg)
            if multi_task:
                # multi-task: accent classification loss on the encoder output
                disc_loss = calculate_multi_task(accent_pred, accent_id)
            else:
                # calculate discriminator loss and encoder loss
                disc_loss, enc_loss = calculate_adversarial(
                    accent_pred, accent_id)

        strs_golds, strs_hyps = [], []

        for j in range(len(gold)):
            ut_gold = gold[j]
            strs_golds.append("".join(
                [vocab.id2label[int(x)] for x in ut_gold]))

        for j in range(len(hyp)):
            ut_hyp = hyp[j]
            strs_hyps.append("".join([vocab.id2label[int(x)] for x in ut_hyp]))

        # convert src_percentages (the fraction of the padded length that is
        # real input) into absolute output lengths; note mul_ mutates in place
        seq_length = pred.size(1)
        sizes = src_percentages.mul_(int(seq_length)).int()

        loss, _ = calculate_metrics(pred,
                                    gold,
                                    vocab.PAD_ID,
                                    input_lengths=sizes,
                                    target_lengths=trg_lengths,
                                    smoothing=smoothing,
                                    loss_type=loss_type)

        if loss is None:
            raise ValueError("loss is None")

        if not torch.isfinite(loss):
            logging.info("Found non-finite loss, masking")
            print("Found non-finite loss, masking")
            loss = torch.where(torch.isfinite(loss), loss,
                               torch.zeros_like(loss))  # zero out NaN/Inf

        total_cer, total_wer, total_char, total_word = 0, 0, 0, 0
        for j in range(len(strs_hyps)):
            strs_hyps[j] = post_process(strs_hyps[j], vocab.special_token_list)
            strs_golds[j] = post_process(strs_golds[j],
                                         vocab.special_token_list)
            cer = calculate_cer(strs_hyps[j].replace(' ', ''),
                                strs_golds[j].replace(' ', ''))
            wer = calculate_wer(strs_hyps[j], strs_golds[j])
            total_cer += cer
            total_wer += wer
            total_char += len(strs_golds[j].replace(' ', ''))
            total_word += len(strs_golds[j].split(" "))

        if verbose:
            print('Total CER', total_cer)
            print('Total char', total_char)

            print("PRED:", strs_hyps)
            print("GOLD:", strs_golds, flush=True)

        if discriminator is None:
            return loss, total_cer, total_char
        else:
            if multi_task:
                return loss, total_cer, total_char, disc_loss
            else:
                return loss, total_cer, total_char, disc_loss, enc_loss
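
calculate_metrics is called throughout these examples with the signature calculate_metrics(pred, gold, pad_id, input_lengths=..., target_lengths=..., smoothing=..., loss_type=...) (Examples #5 and #6 use an older variant without pad_id). A hedged sketch assuming the two usual branches are CTC loss and label-smoothed cross-entropy; the loss_type keys, the reduction, and the (loss, num_correct) packaging are assumptions:

import torch
import torch.nn.functional as F

def calculate_metrics(pred, gold, pad_id, input_lengths=None,
                      target_lengths=None, smoothing=0.0, loss_type="ce"):
    """Hypothetical sketch: return (loss, num_correct) for one batch."""
    if loss_type == "ctc":
        # pred: (batch, time, vocab); F.ctc_loss expects (time, batch, vocab)
        log_probs = pred.transpose(0, 1).log_softmax(2)
        loss = F.ctc_loss(log_probs, gold, input_lengths, target_lengths)
        num_correct = 0  # per-token accuracy is not well defined for CTC
    else:
        pred = pred.reshape(-1, pred.size(-1))
        gold = gold.reshape(-1)
        # label_smoothing= requires PyTorch >= 1.10
        loss = F.cross_entropy(pred, gold, ignore_index=pad_id,
                               label_smoothing=smoothing)
        non_pad = gold.ne(pad_id)
        num_correct = pred.argmax(-1).eq(gold).masked_select(non_pad).sum().item()
    return loss, num_correct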
Example #5
    # Method of a trainer class; assumes module-level imports of sys, time,
    # logging, torch, and tqdm, plus the project's constant module and its
    # calculate_metrics / calculate_cer / calculate_wer / save_model helpers.
    def train(self,
              model,
              train_loader,
              train_sampler,
              valid_loader_list,
              opt,
              loss_type,
              start_epoch,
              num_epochs,
              label2id,
              id2label,
              last_metrics=None):
        """
        Training
        args:
            model: Model object
            train_loader: DataLoader object of the training set
            train_sampler: sampler whose shuffle(epoch) reshuffles the data
            valid_loader_list: a list of validation DataLoader objects
            opt: optimizer wrapper exposing ._rate (see the sketch below)
            loss_type: loss used by calculate_metrics (e.g. CTC)
            start_epoch: start epoch (> 0 if you resume the process)
            num_epochs: exclusive upper bound of the epoch range
            label2id, id2label: label maps stored with each checkpoint
            last_metrics: metrics of the resumed run (if resuming)
        """
        history = []
        start_time = time.time()
        best_valid_loss = (float("inf") if last_metrics is None
                           else last_metrics['valid_loss'])
        smoothing = constant.args.label_smoothing

        logging.info("name " + constant.args.name)

        for epoch in range(start_epoch, num_epochs):
            sys.stdout.flush()
            total_loss, total_cer, total_wer, total_char, total_word = 0, 0, 0, 0, 0

            start_iter = 0

            logging.info("TRAIN")
            model.train()
            pbar = tqdm(iter(train_loader),
                        leave=True,
                        total=len(train_loader))
            for i, data in enumerate(pbar, start=start_iter):
                src, tgt, src_percentages, src_lengths, tgt_lengths = data

                if constant.USE_CUDA:
                    src = src.cuda()
                    tgt = tgt.cuda()

                opt.zero_grad()

                pred, gold, hyp_seq, gold_seq = model(src,
                                                      src_lengths,
                                                      tgt,
                                                      verbose=False)

                try:  # CTC decode; may fail on degenerate (NaN) predictions
                    strs_gold, strs_hyps = [], []
                    for ut_gold in gold_seq:
                        str_gold = ""
                        for x in ut_gold:
                            if int(x) == constant.PAD_TOKEN:
                                break
                            str_gold = str_gold + id2label[int(x)]
                        strs_gold.append(str_gold)
                    for ut_hyp in hyp_seq:
                        str_hyp = ""
                        for x in ut_hyp:
                            if int(x) == constant.PAD_TOKEN:
                                break
                            str_hyp = str_hyp + id2label[int(x)]
                        strs_hyps.append(str_hyp)
                except Exception as e:
                    print(e)
                    logging.info("NaN predictions")
                    continue

                seq_length = pred.size(1)
                # scale src_percentages to absolute output lengths (in-place);
                # Variable(..., requires_grad=False) is no longer needed
                sizes = src_percentages.mul_(int(seq_length)).int()

                loss, num_correct = calculate_metrics(
                    pred,
                    gold,
                    input_lengths=sizes,
                    target_lengths=tgt_lengths,
                    smoothing=smoothing,
                    loss_type=loss_type)

                if not torch.isfinite(loss):
                    logging.info("Found non-finite loss, skipping batch")
                    continue

                for j in range(len(strs_hyps)):
                    strs_hyps[j] = strs_hyps[j].replace(
                        constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
                    strs_gold[j] = strs_gold[j].replace(
                        constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
                    cer = calculate_cer(strs_hyps[j].replace(' ', ''),
                                        strs_gold[j].replace(' ', ''))
                    wer = calculate_wer(strs_hyps[j], strs_gold[j])
                    total_cer += cer
                    total_wer += wer
                    total_char += len(strs_gold[j].replace(' ', ''))
                    total_word += len(strs_gold[j].split(" "))

                loss.backward()

                if constant.args.clip:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   constant.args.max_norm)

                opt.step()

                total_loss += loss.item()
                non_pad_mask = gold.ne(constant.PAD_TOKEN)
                num_word = non_pad_mask.sum().item()

                pbar.set_description(
                    "(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% LR:{:.7f}".
                    format((epoch + 1), total_loss / (i + 1),
                           total_cer * 100 / total_char, opt._rate))
            logging.info(
                "(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% LR:{:.7f}".format(
                    (epoch + 1), total_loss / (len(train_loader)),
                    total_cer * 100 / total_char, opt._rate))

            # evaluate
            print("")
            logging.info("VALID")
            model.eval()

            for ind, valid_loader in enumerate(valid_loader_list):

                total_valid_loss, total_valid_cer, total_valid_wer, total_valid_char, total_valid_word = 0, 0, 0, 0, 0
                valid_pbar = tqdm(iter(valid_loader),
                                  leave=True,
                                  total=len(valid_loader))
                for i, data in enumerate(valid_pbar):
                    src, tgt, src_percentages, src_lengths, tgt_lengths = data

                    if constant.USE_CUDA:
                        src = src.cuda()
                        tgt = tgt.cuda()

                    pred, gold, hyp_seq, gold_seq = model(src,
                                                          src_lengths,
                                                          tgt,
                                                          verbose=False)

                    seq_length = pred.size(1)
                    sizes = src_percentages.mul_(int(seq_length)).int()

                    loss, num_correct = calculate_metrics(
                        pred,
                        gold,
                        input_lengths=sizes,
                        target_lengths=tgt_lengths,
                        smoothing=smoothing,
                        loss_type=loss_type)

                    if not torch.isfinite(loss):
                        logging.info("Found non-finite loss, skipping batch")
                        continue

                    try:  # CTC decode; may fail on degenerate (NaN) predictions
                        strs_gold, strs_hyps = [], []
                        for ut_gold in gold_seq:
                            str_gold = ""
                            for x in ut_gold:
                                if int(x) == constant.PAD_TOKEN:
                                    break
                                str_gold = str_gold + id2label[int(x)]
                            strs_gold.append(str_gold)
                        for ut_hyp in hyp_seq:
                            str_hyp = ""
                            for x in ut_hyp:
                                if int(x) == constant.PAD_TOKEN:
                                    break
                                str_hyp = str_hyp + id2label[int(x)]
                            strs_hyps.append(str_hyp)
                    except Exception as e:
                        print(e)
                        logging.info("NaN predictions")
                        continue

                    for j in range(len(strs_hyps)):
                        strs_hyps[j] = strs_hyps[j].replace(
                            constant.SOS_CHAR,
                            '').replace(constant.EOS_CHAR, '')
                        strs_gold[j] = strs_gold[j].replace(
                            constant.SOS_CHAR,
                            '').replace(constant.EOS_CHAR, '')
                        cer = calculate_cer(strs_hyps[j].replace(' ', ''),
                                            strs_gold[j].replace(' ', ''))
                        wer = calculate_wer(strs_hyps[j], strs_gold[j])
                        total_valid_cer += cer
                        total_valid_wer += wer
                        total_valid_char += len(strs_gold[j].replace(' ', ''))
                        total_valid_word += len(strs_gold[j].split(" "))

                    total_valid_loss += loss.item()
                    valid_pbar.set_description(
                        "VALID SET {} LOSS:{:.4f} CER:{:.2f}%".format(
                            ind, total_valid_loss / (i + 1),
                            total_valid_cer * 100 / total_valid_char))
                logging.info("VALID SET {} LOSS:{:.4f} CER:{:.2f}%".format(
                    ind, total_valid_loss / (len(valid_loader)),
                    total_valid_cer * 100 / total_valid_char))

            metrics = {}
            metrics["train_loss"] = total_loss / len(train_loader)
            metrics["valid_loss"] = total_valid_loss / (len(valid_loader))
            metrics["train_cer"] = total_cer
            metrics["train_wer"] = total_wer
            metrics["valid_cer"] = total_valid_cer
            metrics["valid_wer"] = total_valid_wer
            metrics["history"] = history
            history.append(metrics)

            if epoch % constant.args.save_every == 0:
                save_model(model, (epoch + 1),
                           opt,
                           metrics,
                           label2id,
                           id2label,
                           best_model=False)

            # save the best model (tracked on the last validation set)
            if best_valid_loss > total_valid_loss / len(valid_loader):
                best_valid_loss = total_valid_loss / len(valid_loader)
                save_model(model, (epoch + 1),
                           opt,
                           metrics,
                           label2id,
                           id2label,
                           best_model=True)

            if constant.args.shuffle:
                logging.info("SHUFFLE")
                print("SHUFFLE")
                train_sampler.shuffle(epoch)
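
The opt object used by both trainers is not a bare torch.optim optimizer: Example #5 reads opt._rate for the progress bar, and Example #6 steps opt.optimizer directly. That usage matches a Noam-style warmup wrapper from the original Transformer recipe. A sketch of such a wrapper; everything beyond the ._rate, .step, .zero_grad, and .optimizer surface is an assumption:

class NoamOpt:
    """Hypothetical warmup/decay learning-rate wrapper around an optimizer."""

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self.model_size = model_size
        self.factor = factor
        self.warmup = warmup
        self._step = 0
        self._rate = 0.0

    def rate(self, step):
        # lr = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
        return self.factor * (self.model_size ** -0.5) * min(
            step ** -0.5, step * self.warmup ** -1.5)

    def step(self):
        self._step += 1
        self._rate = self.rate(self._step)
        for group in self.optimizer.param_groups:
            group["lr"] = self._rate
        self.optimizer.step()

    def zero_grad(self):
        self.optimizer.zero_grad()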
Example #6
    # As in Example #5, a trainer-class method; additionally assumes numpy
    # imported as np.
    def train(self,
              model,
              train_loader,
              train_sampler,
              valid_loaders,
              opt,
              loss_type,
              start_epoch,
              num_epochs,
              label2id,
              id2label,
              last_metrics=None,
              logger=None):
        """
        Training
        args:
            model: Model object
            train_loader: DataLoader object of the training set
            valid_loaders: list of DataLoader object of the validation set
            opt: Optimizer object
            start_epoch: start epoch (> 0 if you resume the process)
            num_epochs: last epoch
            last_metrics: (if resume)
        """
        if logger is not None:
            sys.stdout = logger

        start_time = time.time()
        best_valid_loss = (float("inf") if last_metrics is None
                           else last_metrics['valid_loss'])
        smoothing = constant.args.label_smoothing

        history = []

        for epoch in range(start_epoch, num_epochs):
            sys.stdout.flush()
            total_loss, total_cer, total_wer, total_char, total_word = 0, 0, 0, 0, 0
            start_iter = 0

            print("TRAIN")
            model.train()
            pbar = tqdm(iter(train_loader),
                        leave=True,
                        total=len(train_loader))
            for i, data in enumerate(pbar, start=start_iter):
                src, tgt, src_percentages, src_lengths, tgt_lengths = data

                if constant.USE_CUDA:
                    src = src.cuda()
                    tgt = tgt.cuda()

                opt.optimizer.zero_grad()

                pred, gold, hyp_seq, gold_seq = model(
                    src,
                    input_lengths=src_lengths,
                    padded_target=tgt,
                    verbose=constant.args.verbose)

                strs_gold = [
                    "".join([id2label[int(x)] for x in gold])
                    for gold in gold_seq
                ]
                strs_hyps = [
                    "".join([id2label[int(x)] for x in hyp]) for hyp in hyp_seq
                ]

                loss, num_correct = calculate_metrics(
                    pred,
                    gold,
                    smoothing=smoothing,
                    loss_type=loss_type,
                    input_lengths=src_lengths,
                    target_lengths=tgt_lengths)

                if constant.args.verbose:
                    print("GOLD", strs_gold)
                    print("HYP", strs_hyps)

                for j in range(len(strs_hyps)):
                    cer = calculate_cer(strs_hyps[j], strs_gold[j])
                    wer = calculate_wer(strs_hyps[j], strs_gold[j])
                    total_cer += cer
                    total_wer += wer
                    total_char += len(strs_gold[j])
                    total_word += len(strs_gold[j].split(" "))

                loss.backward()
                opt.optimizer.step()

                total_loss += loss.detach().item()
                non_pad_mask = gold.ne(constant.PAD_TOKEN)
                num_word = non_pad_mask.sum().item()

                pbar.set_description(
                    "(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% WER:{:.2f}%".
                    format((epoch + 1), total_loss / (i + 1),
                           total_cer * 100 / total_char,
                           total_wer * 100 / total_word))
            print(
                "(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% WER:{:.2f}%".format(
                    (epoch + 1), total_loss / (len(train_loader)),
                    total_cer * 100 / total_char,
                    total_wer * 100 / total_word))

            print("VALID")
            all_valid_loss = []
            for valid_task_id, valid_loader in enumerate(valid_loaders):
                model.eval()
                sys.stdout.flush()

                total_valid_loss, total_valid_cer, total_valid_wer, total_valid_char, total_valid_word = 0, 0, 0, 0, 0
                valid_pbar = tqdm(iter(valid_loader),
                                  leave=True,
                                  total=len(valid_loader))
                for i, data in enumerate(valid_pbar):
                    src, tgt, src_percentages, src_lengths, tgt_lengths = data

                    if constant.USE_CUDA:
                        src = src.cuda()
                        tgt = tgt.cuda()

                    pred, gold, hyp_seq, gold_seq = model(
                        src,
                        input_lengths=src_lengths,
                        padded_target=tgt,
                        verbose=constant.args.verbose)
                    loss, num_correct = calculate_metrics(
                        pred,
                        gold,
                        smoothing=smoothing,
                        loss_type=loss_type,
                        input_lengths=src_lengths,
                        target_lengths=tgt_lengths)

                    strs_gold = [
                        "".join([id2label[int(x)] for x in gold])
                        for gold in gold_seq
                    ]
                    strs_hyps = [
                        "".join([id2label[int(x)] for x in hyp])
                        for hyp in hyp_seq
                    ]

                    for j in range(len(strs_hyps)):
                        cer = calculate_cer(strs_hyps[j], strs_gold[j])
                        wer = calculate_wer(strs_hyps[j], strs_gold[j])
                        total_valid_cer += cer
                        total_valid_wer += wer
                        total_valid_char += len(strs_gold[j])
                        total_valid_word += len(strs_gold[j].split(" "))

                    total_valid_loss += loss.detach().item()
                    valid_pbar.set_description(
                        "(Epoch {}) TASK:{} VALID LOSS:{:.4f} CER:{:.2f}% WER:{:.2f}%"
                        .format((epoch + 1), valid_task_id,
                                total_valid_loss / (i + 1),
                                total_valid_cer * 100 / total_valid_char,
                                total_valid_wer * 100 / total_valid_word))
                all_valid_loss.append(total_valid_loss / len(valid_loader))
                print(
                    "(Epoch {}) TASK:{} VALID LOSS:{:.4f} CER:{:.2f}% WER:{:.2f}%"
                    .format((epoch + 1), valid_task_id,
                            total_valid_loss / (len(valid_loader)),
                            total_valid_cer * 100 / total_valid_char,
                            total_valid_wer * 100 / total_valid_word))

            metrics = {}
            metrics["train_loss"] = total_loss / len(train_loader)
            metrics["valid_loss"] = np.mean(np.array(all_valid_loss))
            metrics["valid_losses"] = all_valid_loss
            metrics["train_cer"] = total_cer
            metrics["train_wer"] = total_wer
            metrics["valid_cer"] = total_valid_cer
            metrics["valid_wer"] = total_valid_wer
            metrics["history"] = history
            history.append(metrics)

            if epoch % constant.args.save_every == 0:
                save_model(model, (epoch + 1),
                           opt,
                           metrics,
                           label2id,
                           id2label,
                           best_model=False)

            # save the best model (tracked on the last validation set's loss,
            # not the mean stored in metrics["valid_loss"])
            if best_valid_loss > total_valid_loss / len(valid_loader):
                best_valid_loss = total_valid_loss / len(valid_loader)
                save_model(model, (epoch + 1),
                           opt,
                           metrics,
                           label2id,
                           id2label,
                           best_model=True)

            if constant.args.shuffle:
                print("SHUFFLE")
                train_sampler.shuffle(epoch)
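
save_model is the project's checkpoint helper, called above as save_model(model, epoch, opt, metrics, label2id, id2label, best_model=...). A minimal sketch; the save directory, file names, and checkpoint keys are assumptions:

import os
import torch

def save_model(model, epoch, opt, metrics, label2id, id2label,
               best_model=False):
    """Hypothetical sketch: persist model/optimizer state with metadata."""
    os.makedirs("save", exist_ok=True)
    name = "best_model.pt" if best_model else "epoch_{}.pt".format(epoch)
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        # opt is the wrapper sketched after Example #5; save the inner state
        "optimizer_state_dict": opt.optimizer.state_dict(),
        "metrics": metrics,
        "label2id": label2id,
        "id2label": id2label,
    }, os.path.join("save", name))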