def evaluate(model, vocab, test_loader, args, lm=None, start_token=-1):
    """
    Evaluation
    args:
        model: Model object
        vocab: Vocab object (id2label mapping and special token list)
        test_loader: DataLoader object
        args: decoding arguments (beam search, LM rescoring, etc.)
        lm: optional language model used for rescoring
        start_token: decoder start token id (-1 uses the model default)
    """
    model.eval()

    total_word, total_char, total_cer, total_wer = 0, 0, 0, 0
    total_en_cer, total_zh_cer, total_en_char, total_zh_char = 0, 0, 0, 0
    total_hyp_char = 0
    total_time = 0

    with torch.no_grad():
        test_pbar = tqdm(iter(test_loader), leave=False, total=len(test_loader))
        for data in test_pbar:
            src, trg, src_percentages, src_lengths, trg_lengths = data

            if USE_CUDA:
                src = src.cuda()
                trg = trg.cuda()

            start_time = time.time()
            batch_ids_hyps, batch_strs_hyps, batch_strs_gold = model.evaluate(
                src, src_lengths, trg, args,
                lm_rescoring=args.lm_rescoring, lm=lm, lm_weight=args.lm_weight,
                beam_search=args.beam_search, beam_width=args.beam_width,
                beam_nbest=args.beam_nbest, c_weight=args.c_weight,
                start_token=start_token, verbose=args.verbose)

            for x in range(len(batch_strs_gold)):
                # strip special tokens before scoring
                hyp = post_process(batch_strs_hyps[x], vocab.special_token_list)
                gold = post_process(batch_strs_gold[x], vocab.special_token_list)

                wer = calculate_wer(hyp, gold)
                cer = calculate_cer(hyp.strip(), gold.strip())

                if args.verbose:
                    print("HYP:", hyp)
                    print("GOLD:", gold)
                    print("CER:", cer)

                # language-specific CER for code-switched English/Chinese text
                en_cer, zh_cer, num_en_char, num_zh_char = calculate_cer_en_zh(hyp, gold)
                total_en_cer += en_cer
                total_zh_cer += zh_cer
                total_en_char += num_en_char
                total_zh_char += num_zh_char
                total_hyp_char += len(hyp)

                total_wer += wer
                total_cer += cer
                total_word += len(gold.split(" "))
                total_char += len(gold)

            end_time = time.time()
            total_time += end_time - start_time
            diff_time_per_word = total_time / total_word  # running decode time per gold word

            test_pbar.set_description(
                "TEST CER:{:.2f}% WER:{:.2f}% CER_EN:{:.2f}% CER_ZH:{:.2f}% "
                "TOTAL_TIME:{:.7f} TOTAL HYP CHAR:{}".format(
                    total_cer * 100 / total_char,
                    total_wer * 100 / total_word,
                    total_en_cer * 100 / max(1, total_en_char),
                    total_zh_cer * 100 / max(1, total_zh_char),
                    total_time, total_hyp_char))

    print("TEST CER:{:.2f}% WER:{:.2f}% CER_EN:{:.2f}% CER_ZH:{:.2f}% "
          "TOTAL_TIME:{:.7f} TOTAL HYP CHAR:{}".format(
              total_cer * 100 / total_char,
              total_wer * 100 / total_word,
              total_en_cer * 100 / max(1, total_en_char),
              total_zh_cer * 100 / max(1, total_zh_char),
              total_time, total_hyp_char),
          flush=True)
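# Hedged usage sketch: `load_checkpoint` and the loader construction are
# hypothetical placeholders for whatever this project provides; only the
# evaluate() signature above is taken from the code.
#
#     model, vocab = load_checkpoint(args.continue_from)  # hypothetical helper
#     evaluate(model, vocab, test_loader, args)           # greedy/beam decode per args
#
# Since evaluate() wraps decoding in torch.no_grad(), the test batch size can
# usually be larger than the training batch size without exhausting memory.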
def evaluate(model, test_loader, lm=None):
    """
    Evaluation
    args:
        model: Model object
        test_loader: DataLoader object
    """
    model.eval()

    total_word, total_char, total_cer, total_wer = 0, 0, 0, 0

    with torch.no_grad():
        test_pbar = tqdm(iter(test_loader), leave=True, total=len(test_loader))
        for data in test_pbar:
            src, tgt, src_percentages, src_lengths, tgt_lengths = data

            if constant.USE_CUDA:
                src = src.cuda()
                tgt = tgt.cuda()

            batch_ids_hyps, batch_strs_hyps, batch_strs_gold = model.evaluate(
                src, src_lengths, tgt,
                beam_search=constant.args.beam_search,
                beam_width=constant.args.beam_width,
                beam_nbest=constant.args.beam_nbest,
                lm=lm,
                lm_rescoring=constant.args.lm_rescoring,
                lm_weight=constant.args.lm_weight,
                c_weight=constant.args.c_weight,
                verbose=constant.args.verbose)

            for x in range(len(batch_strs_gold)):
                hyp = batch_strs_hyps[x].replace(constant.EOS_CHAR, "").replace(
                    constant.SOS_CHAR, "").replace(constant.PAD_CHAR, "")
                gold = batch_strs_gold[x].replace(constant.EOS_CHAR, "").replace(
                    constant.SOS_CHAR, "").replace(constant.PAD_CHAR, "")

                wer = calculate_wer(hyp, gold)
                cer = calculate_cer(hyp.strip(), gold.strip())

                total_wer += wer
                total_cer += cer
                total_word += len(gold.split(" "))
                total_char += len(gold)

            test_pbar.set_description("TEST CER:{:.2f}% WER:{:.2f}%".format(
                total_cer * 100 / total_char, total_wer * 100 / total_word))
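# The .replace() chain above is what post_process() generalizes in the newer
# evaluate(): it strips special tokens before scoring. A minimal sketch of
# such a helper (illustrative only; the real post_process may differ):
#
#     def strip_special_tokens(text, special_tokens):
#         for tok in special_tokens:
#             text = text.replace(tok, "")
#         return text
#
#     strip_special_tokens("<s>hello world</s><pad>", ["<s>", "</s>", "<pad>"])
#     # -> 'hello world'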
def train_one_batch(self, model, vocab, src, trg, src_percentages,
                    src_lengths, trg_lengths, smoothing, loss_type):
    pred, gold, hyp = model(src, src_lengths, trg, verbose=False)

    strs_golds, strs_hyps = [], []
    for ut_gold in gold:
        strs_golds.append("".join([vocab.id2label[int(x)] for x in ut_gold]))
    for ut_hyp in hyp:
        strs_hyps.append("".join([vocab.id2label[int(x)] for x in ut_hyp]))

    # handling the last batch: src_percentages stores each utterance's share
    # of the padded length, so scaling by the output length recovers true sizes
    seq_length = pred.size(1)
    sizes = src_percentages.mul_(int(seq_length)).int()

    loss, num_correct = calculate_metrics(
        pred, gold, vocab.PAD_ID, input_lengths=sizes,
        target_lengths=trg_lengths, smoothing=smoothing, loss_type=loss_type)

    if loss is None:
        print("loss is None")

    # mask non-finite (NaN/Inf) losses so backward() stays well-defined
    if not torch.isfinite(loss).all():
        logging.info("Found non-finite loss, masking")
        print("Found non-finite loss, masking")
        loss = torch.where(torch.isfinite(loss), loss, torch.zeros_like(loss))

    total_cer, total_wer, total_char, total_word = 0, 0, 0, 0
    for j in range(len(strs_hyps)):
        strs_hyps[j] = post_process(strs_hyps[j], vocab.special_token_list)
        strs_golds[j] = post_process(strs_golds[j], vocab.special_token_list)

        cer = calculate_cer(strs_hyps[j].replace(' ', ''),
                            strs_golds[j].replace(' ', ''))
        wer = calculate_wer(strs_hyps[j], strs_golds[j])
        total_cer += cer
        total_wer += wer
        total_char += len(strs_golds[j].replace(' ', ''))
        total_word += len(strs_golds[j].split(" "))

    return loss, total_cer, total_char
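# A minimal, self-contained sketch of the non-finite loss masking above
# (pure PyTorch, independent of the trainer):
#
#     >>> import torch
#     >>> losses = torch.tensor([0.7, float('inf'), float('nan'), 1.2])
#     >>> torch.where(torch.isfinite(losses), losses, torch.zeros_like(losses))
#     tensor([0.7000, 0.0000, 0.0000, 1.2000])
#
# Zeroing the bad entries keeps backward() well-defined instead of propagating
# NaN/Inf gradients through the whole model.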
def forward_one_batch(self, model, vocab, src, trg, src_percentages,
                      src_lengths, trg_lengths, smoothing, loss_type,
                      verbose=False, discriminator=None, accent_id=None,
                      multi_task=False):
    if discriminator is None:
        pred, gold, hyp = model(src, src_lengths, trg, verbose=False)
    else:
        enc_output = model.encode(src, src_lengths)
        accent_pred = discriminator(torch.sum(enc_output, dim=1))
        pred, gold, hyp = model.decode(enc_output, src_lengths, trg)

        if multi_task:
            # multi-task objective: accent classification loss only
            disc_loss = calculate_multi_task(accent_pred, accent_id)
        else:
            # adversarial objective: discriminator loss and encoder loss
            disc_loss, enc_loss = calculate_adversarial(accent_pred, accent_id)

    strs_golds, strs_hyps = [], []
    for ut_gold in gold:
        strs_golds.append("".join([vocab.id2label[int(x)] for x in ut_gold]))
    for ut_hyp in hyp:
        strs_hyps.append("".join([vocab.id2label[int(x)] for x in ut_hyp]))

    # handling the last batch: recover true input lengths from the percentages
    seq_length = pred.size(1)
    sizes = src_percentages.mul_(int(seq_length)).int()

    loss, _ = calculate_metrics(
        pred, gold, vocab.PAD_ID, input_lengths=sizes,
        target_lengths=trg_lengths, smoothing=smoothing, loss_type=loss_type)

    if loss is None:
        print("loss is None")

    # mask non-finite (NaN/Inf) losses so backward() stays well-defined
    if not torch.isfinite(loss).all():
        logging.info("Found non-finite loss, masking")
        print("Found non-finite loss, masking")
        loss = torch.where(torch.isfinite(loss), loss, torch.zeros_like(loss))

    total_cer, total_wer, total_char, total_word = 0, 0, 0, 0
    for j in range(len(strs_hyps)):
        strs_hyps[j] = post_process(strs_hyps[j], vocab.special_token_list)
        strs_golds[j] = post_process(strs_golds[j], vocab.special_token_list)

        cer = calculate_cer(strs_hyps[j].replace(' ', ''),
                            strs_golds[j].replace(' ', ''))
        wer = calculate_wer(strs_hyps[j], strs_golds[j])
        total_cer += cer
        total_wer += wer
        total_char += len(strs_golds[j].replace(' ', ''))
        total_word += len(strs_golds[j].split(" "))

    if verbose:
        print('Total CER', total_cer)
        print('Total char', total_char)
        print("PRED:", strs_hyps)
        print("GOLD:", strs_golds, flush=True)

    if discriminator is None:
        return loss, total_cer, total_char
    if multi_task:
        return loss, total_cer, total_char, disc_loss
    return loss, total_cer, total_char, disc_loss, enc_loss
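# forward_one_batch() returns a different arity per mode; a hedged consumption
# sketch (the caller-side names `disc` and `accents` are hypothetical):
#
#     if discriminator is None:
#         loss, cer, n_char = self.forward_one_batch(model, vocab, ...)
#     elif multi_task:
#         loss, cer, n_char, disc_loss = self.forward_one_batch(
#             model, vocab, ..., discriminator=disc, accent_id=accents, multi_task=True)
#     else:
#         loss, cer, n_char, disc_loss, enc_loss = self.forward_one_batch(
#             model, vocab, ..., discriminator=disc, accent_id=accents)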
def train(self, model, train_loader, train_sampler, valid_loader_list, opt,
          loss_type, start_epoch, num_epochs, label2id, id2label,
          last_metrics=None):
    """
    Training
    args:
        model: Model object
        train_loader: DataLoader object of the training set
        train_sampler: sampler used to reshuffle the training set
        valid_loader_list: a list of validation DataLoader objects
        opt: Optimizer object
        loss_type: loss function to use
        start_epoch: start epoch (> 0 if you resume the process)
        num_epochs: last epoch
        label2id / id2label: vocabulary mappings
        last_metrics: metrics of the resumed run (if resume)
    """
    history = []
    start_time = time.time()
    best_valid_loss = 1e9 if last_metrics is None else last_metrics['valid_loss']
    smoothing = constant.args.label_smoothing

    logging.info("name " + constant.args.name)

    for epoch in range(start_epoch, num_epochs):
        sys.stdout.flush()
        total_loss, total_cer, total_wer, total_char, total_word = 0, 0, 0, 0, 0
        start_iter = 0

        logging.info("TRAIN")
        model.train()
        pbar = tqdm(iter(train_loader), leave=True, total=len(train_loader))
        for i, data in enumerate(pbar, start=start_iter):
            src, tgt, src_percentages, src_lengths, tgt_lengths = data

            if constant.USE_CUDA:
                src = src.cuda()
                tgt = tgt.cuda()

            opt.zero_grad()
            pred, gold, hyp_seq, gold_seq = model(src, src_lengths, tgt, verbose=False)

            try:  # handle case for CTC
                strs_gold, strs_hyps = [], []
                for ut_gold in gold_seq:
                    str_gold = ""
                    for x in ut_gold:
                        if int(x) == constant.PAD_TOKEN:
                            break
                        str_gold = str_gold + id2label[int(x)]
                    strs_gold.append(str_gold)
                for ut_hyp in hyp_seq:
                    str_hyp = ""
                    for x in ut_hyp:
                        if int(x) == constant.PAD_TOKEN:
                            break
                        str_hyp = str_hyp + id2label[int(x)]
                    strs_hyps.append(str_hyp)
            except Exception as e:
                print(e)
                logging.info("NaN predictions")
                continue

            seq_length = pred.size(1)
            sizes = src_percentages.mul_(int(seq_length)).int()

            loss, num_correct = calculate_metrics(
                pred, gold, input_lengths=sizes, target_lengths=tgt_lengths,
                smoothing=smoothing, loss_type=loss_type)

            # skip batches with non-finite (NaN/Inf) loss
            if not torch.isfinite(loss).all():
                logging.info("Found non-finite loss, skipping batch")
                continue

            for j in range(len(strs_hyps)):
                strs_hyps[j] = strs_hyps[j].replace(
                    constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
                strs_gold[j] = strs_gold[j].replace(
                    constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')

                cer = calculate_cer(strs_hyps[j].replace(' ', ''),
                                    strs_gold[j].replace(' ', ''))
                wer = calculate_wer(strs_hyps[j], strs_gold[j])
                total_cer += cer
                total_wer += wer
                total_char += len(strs_gold[j].replace(' ', ''))
                total_word += len(strs_gold[j].split(" "))

            loss.backward()

            if constant.args.clip:
                torch.nn.utils.clip_grad_norm_(model.parameters(), constant.args.max_norm)

            opt.step()

            total_loss += loss.item()
            non_pad_mask = gold.ne(constant.PAD_TOKEN)
            num_word = non_pad_mask.sum().item()

            pbar.set_description(
                "(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% LR:{:.7f}".format(
                    (epoch + 1), total_loss / (i + 1),
                    total_cer * 100 / total_char, opt._rate))

        logging.info(
            "(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% LR:{:.7f}".format(
                (epoch + 1), total_loss / len(train_loader),
                total_cer * 100 / total_char, opt._rate))

        # evaluate
        print("")
        logging.info("VALID")
        model.eval()
        for ind in range(len(valid_loader_list)):
            valid_loader = valid_loader_list[ind]

            total_valid_loss, total_valid_cer, total_valid_wer, total_valid_char, total_valid_word = 0, 0, 0, 0, 0
            valid_pbar = tqdm(iter(valid_loader), leave=True, total=len(valid_loader))
            for i, data in enumerate(valid_pbar):
                src, tgt, src_percentages, src_lengths, tgt_lengths = data

                if constant.USE_CUDA:
                    src = src.cuda()
                    tgt = tgt.cuda()

                pred, gold, hyp_seq, gold_seq = model(src, src_lengths, tgt, verbose=False)

                seq_length = pred.size(1)
                sizes = src_percentages.mul_(int(seq_length)).int()

                loss, num_correct = calculate_metrics(
                    pred, gold, input_lengths=sizes, target_lengths=tgt_lengths,
                    smoothing=smoothing, loss_type=loss_type)

                # skip batches with non-finite (NaN/Inf) loss
                if not torch.isfinite(loss).all():
                    logging.info("Found non-finite loss, skipping batch")
                    continue

                try:  # handle case for CTC
                    strs_gold, strs_hyps = [], []
                    for ut_gold in gold_seq:
                        str_gold = ""
                        for x in ut_gold:
                            if int(x) == constant.PAD_TOKEN:
                                break
                            str_gold = str_gold + id2label[int(x)]
                        strs_gold.append(str_gold)
                    for ut_hyp in hyp_seq:
                        str_hyp = ""
                        for x in ut_hyp:
                            if int(x) == constant.PAD_TOKEN:
                                break
                            str_hyp = str_hyp + id2label[int(x)]
                        strs_hyps.append(str_hyp)
                except Exception as e:
                    print(e)
                    logging.info("NaN predictions")
                    continue

                for j in range(len(strs_hyps)):
                    strs_hyps[j] = strs_hyps[j].replace(
                        constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
                    strs_gold[j] = strs_gold[j].replace(
                        constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')

                    cer = calculate_cer(strs_hyps[j].replace(' ', ''),
                                        strs_gold[j].replace(' ', ''))
                    wer = calculate_wer(strs_hyps[j], strs_gold[j])
                    total_valid_cer += cer
                    total_valid_wer += wer
                    total_valid_char += len(strs_gold[j].replace(' ', ''))
                    total_valid_word += len(strs_gold[j].split(" "))

                total_valid_loss += loss.item()
                valid_pbar.set_description(
                    "VALID SET {} LOSS:{:.4f} CER:{:.2f}%".format(
                        ind, total_valid_loss / (i + 1),
                        total_valid_cer * 100 / total_valid_char))

            logging.info("VALID SET {} LOSS:{:.4f} CER:{:.2f}%".format(
                ind, total_valid_loss / len(valid_loader),
                total_valid_cer * 100 / total_valid_char))

        metrics = {}
        metrics["train_loss"] = total_loss / len(train_loader)
        # note: the valid metrics below reflect the last validation set only
        metrics["valid_loss"] = total_valid_loss / len(valid_loader)
        metrics["train_cer"] = total_cer
        metrics["train_wer"] = total_wer
        metrics["valid_cer"] = total_valid_cer
        metrics["valid_wer"] = total_valid_wer
        metrics["history"] = history
        history.append(metrics)

        if epoch % constant.args.save_every == 0:
            save_model(model, (epoch + 1), opt, metrics, label2id,
                       id2label, best_model=False)

        # save the best model
        if best_valid_loss > total_valid_loss / len(valid_loader):
            best_valid_loss = total_valid_loss / len(valid_loader)
            save_model(model, (epoch + 1), opt, metrics, label2id,
                       id2label, best_model=True)

        if constant.args.shuffle:
            logging.info("SHUFFLE")
            print("SHUFFLE")
            train_sampler.shuffle(epoch)
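# The `sizes = src_percentages.mul_(int(seq_length)).int()` step above recovers
# per-utterance input lengths for the CTC loss: src_percentages holds each
# utterance's fraction of the padded batch length, so scaling it by the model's
# output length gives valid frame counts. A minimal sketch:
#
#     >>> import torch
#     >>> src_percentages = torch.tensor([1.0, 0.75, 0.5])
#     >>> src_percentages.mul_(100).int()   # 100 = padded output length
#     tensor([100,  75,  50], dtype=torch.int32)
#
# Note that mul_ mutates src_percentages in place, so the tensor must not be
# reused as a percentage afterwards.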
def train(self, model, train_loader, train_sampler, valid_loaders, opt,
          loss_type, start_epoch, num_epochs, label2id, id2label,
          last_metrics=None, logger=None):
    """
    Training
    args:
        model: Model object
        train_loader: DataLoader object of the training set
        train_sampler: sampler used to reshuffle the training set
        valid_loaders: list of DataLoader objects of the validation sets
        opt: Optimizer object
        loss_type: loss function to use
        start_epoch: start epoch (> 0 if you resume the process)
        num_epochs: last epoch
        label2id / id2label: vocabulary mappings
        last_metrics: metrics of the resumed run (if resume)
        logger: optional stream that replaces sys.stdout
    """
    if logger is not None:
        sys.stdout = logger

    start_time = time.time()
    best_valid_loss = 1e9 if last_metrics is None else last_metrics['valid_loss']
    smoothing = constant.args.label_smoothing
    history = []

    for epoch in range(start_epoch, num_epochs):
        sys.stdout.flush()
        total_loss, total_cer, total_wer, total_char, total_word = 0, 0, 0, 0, 0
        start_iter = 0

        print("TRAIN")
        model.train()
        pbar = tqdm(iter(train_loader), leave=True, total=len(train_loader))
        for i, data in enumerate(pbar, start=start_iter):
            src, tgt, src_percentages, src_lengths, tgt_lengths = data

            if constant.USE_CUDA:
                src = src.cuda()
                tgt = tgt.cuda()

            opt.optimizer.zero_grad()

            pred, gold, hyp_seq, gold_seq = model(
                src, input_lengths=src_lengths, padded_target=tgt,
                verbose=constant.args.verbose)

            strs_gold = ["".join([id2label[int(x)] for x in seq]) for seq in gold_seq]
            strs_hyps = ["".join([id2label[int(x)] for x in seq]) for seq in hyp_seq]

            loss, num_correct = calculate_metrics(
                pred, gold, smoothing=smoothing, loss_type=loss_type,
                input_lengths=src_lengths, target_lengths=tgt_lengths)

            if constant.args.verbose:
                print("GOLD", strs_gold)
                print("HYP", strs_hyps)

            for j in range(len(strs_hyps)):
                cer = calculate_cer(strs_hyps[j], strs_gold[j])
                wer = calculate_wer(strs_hyps[j], strs_gold[j])
                total_cer += cer
                total_wer += wer
                total_char += len(strs_gold[j])
                total_word += len(strs_gold[j].split(" "))

            loss.backward()
            opt.optimizer.step()

            total_loss += loss.detach().item()
            non_pad_mask = gold.ne(constant.PAD_TOKEN)
            num_word = non_pad_mask.sum().item()

            pbar.set_description(
                "(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% WER:{:.2f}%".format(
                    (epoch + 1), total_loss / (i + 1),
                    total_cer * 100 / total_char,
                    total_wer * 100 / total_word))

        print("(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% WER:{:.2f}%".format(
            (epoch + 1), total_loss / len(train_loader),
            total_cer * 100 / total_char,
            total_wer * 100 / total_word))

        print("VALID")
        all_valid_loss = []
        for valid_task_id in range(len(valid_loaders)):
            model.eval()
            sys.stdout.flush()

            valid_loader = valid_loaders[valid_task_id]

            total_valid_loss, total_valid_cer, total_valid_wer, total_valid_char, total_valid_word = 0, 0, 0, 0, 0
            valid_pbar = tqdm(iter(valid_loader), leave=True, total=len(valid_loader))
            for i, data in enumerate(valid_pbar):
                src, tgt, src_percentages, src_lengths, tgt_lengths = data

                if constant.USE_CUDA:
                    src = src.cuda()
                    tgt = tgt.cuda()

                pred, gold, hyp_seq, gold_seq = model(
                    src, input_lengths=src_lengths, padded_target=tgt,
                    verbose=constant.args.verbose)

                loss, num_correct = calculate_metrics(
                    pred, gold, smoothing=smoothing, loss_type=loss_type,
                    input_lengths=src_lengths, target_lengths=tgt_lengths)

                strs_gold = ["".join([id2label[int(x)] for x in seq]) for seq in gold_seq]
                strs_hyps = ["".join([id2label[int(x)] for x in seq]) for seq in hyp_seq]

                for j in range(len(strs_hyps)):
                    cer = calculate_cer(strs_hyps[j], strs_gold[j])
                    wer = calculate_wer(strs_hyps[j], strs_gold[j])
                    total_valid_cer += cer
                    total_valid_wer += wer
                    total_valid_char += len(strs_gold[j])
                    total_valid_word += len(strs_gold[j].split(" "))

                total_valid_loss += loss.detach().item()
                valid_pbar.set_description(
                    "(Epoch {}) TASK:{} VALID LOSS:{:.4f} CER:{:.2f}% WER:{:.2f}%".format(
                        (epoch + 1), valid_task_id,
                        total_valid_loss / (i + 1),
                        total_valid_cer * 100 / total_valid_char,
                        total_valid_wer * 100 / total_valid_word))

            all_valid_loss.append(total_valid_loss / len(valid_loader))
            print("(Epoch {}) TASK:{} VALID LOSS:{:.4f} CER:{:.2f}% WER:{:.2f}%".format(
                (epoch + 1), valid_task_id,
                total_valid_loss / len(valid_loader),
                total_valid_cer * 100 / total_valid_char,
                total_valid_wer * 100 / total_valid_word))

        metrics = {}
        metrics["train_loss"] = total_loss / len(train_loader)
        metrics["valid_loss"] = np.mean(np.array(all_valid_loss))
        metrics["valid_losses"] = all_valid_loss
        metrics["train_cer"] = total_cer
        metrics["train_wer"] = total_wer
        # note: the CER/WER below reflect the last validation task only
        metrics["valid_cer"] = total_valid_cer
        metrics["valid_wer"] = total_valid_wer
        metrics["history"] = history
        history.append(metrics)

        if epoch % constant.args.save_every == 0:
            save_model(model, (epoch + 1), opt, metrics, label2id,
                       id2label, best_model=False)

        # save the best model, judged by the mean loss over all validation
        # tasks (consistent with last_metrics['valid_loss'] when resuming)
        if best_valid_loss > metrics["valid_loss"]:
            best_valid_loss = metrics["valid_loss"]
            save_model(model, (epoch + 1), opt, metrics, label2id,
                       id2label, best_model=True)

        if constant.args.shuffle:
            print("SHUFFLE")
            train_sampler.shuffle(epoch)
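# The checkpoint criterion above averages the per-task validation losses; a
# minimal sketch (values are made up):
#
#     >>> import numpy as np
#     >>> all_valid_loss = [0.82, 1.10, 0.95]   # one entry per validation task
#     >>> np.mean(np.array(all_valid_loss))     # -> approx. 0.9567
#
# Because each entry is already a per-task mean, the tasks are weighted
# equally in best-model selection even when their validation sets differ in size.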