def eval_word(models, dataset, decode_params, epoch, progressbar=False):
    """Evaluate the word-level model by WER.

    Decodes `dataset` batch by batch with `models[0]`, writes the references
    and hypotheses to `ref.trn` / `hyp.trn` under a directory named after the
    decoding hyperparameters, and accumulates the statistics returned by
    `compute_wer`.

    Args:
        models (list): the models to evaluate (only `models[0]` is used)
        dataset: An instance of a `Dataset' class
        decode_params (dict): decoding hyperparameters. Keys read here:
            `beam_width`, `length_penalty`, `coverage_penalty`,
            `min_len_ratio`, `max_len_ratio`, `rnnlm_weight`, `batch_size`
            and `resolving_unk`
        epoch (int): used only to build the decoding directory name
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): accumulated word errors divided by the number of
            reference words
        num_sub (float): substitution errors divided by the number of
            reference words
        num_ins (float): insertion errors divided by the number of
            reference words
        num_del (float): deletion errors divided by the number of
            reference words
        decode_dir (str): `model.save_path` joined with the decoding
            directory name

    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO(hirofumi): ensemble decoding

    decode_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(
        decode_params['beam_width'])
    decode_dir += '_lp' + str(decode_params['length_penalty'])
    decode_dir += '_cp' + str(decode_params['coverage_penalty'])
    decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(
        decode_params['max_len_ratio'])
    decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

    ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
    hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')

    wer = 0
    num_sub, num_ins, num_del = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO(hirofumi): fix this

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, aw, perm_idx = model.decode(batch['xs'], decode_params,
                                                   exclude_eos=True)
            # Hypotheses come back in `perm_idx` order, so every per-utterance
            # field taken from the batch has to be re-ordered the same way.
            ys = [batch['ys'][i] for i in perm_idx]
            utt_ids = [batch['utt_ids'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                # Reference
                if dataset.is_test:
                    text_ref = ys[b]
                else:
                    text_ref = dataset.idx2word(ys[b])

                # Hypothesis
                text_hyp = dataset.idx2word(best_hyps[b])

                # Resolving UNK
                if decode_params['resolving_unk'] and '<unk>' in text_hyp:
                    # Decode the same utterance the b-th hypothesis belongs to
                    src = int(perm_idx[b])
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'][src:src + 1], decode_params,
                        exclude_eos=True)
                    # task_index=1

                    text_hyp = resolve_unk(
                        text_hyp, best_hyps_sub[0], aw[b], aw_sub[0],
                        dataset.idx2char,
                        diff_time_resolution=2**sum(model.subsample_list) // 2**sum(
                            model.subsample_list[:model.encoder_num_layers_sub - 1]))
                    text_hyp = text_hyp.replace('*', '')

                # Write to trn
                # The last two '_'-separated fields of the utterance id are
                # used as start / end; everything before them as the speaker.
                utt_id_fields = utt_ids[b].replace('-', '_').split('_')
                speaker = '_'.join(utt_id_fields[:-2])
                start = utt_id_fields[-2]
                end = utt_id_fields[-1]
                trn_id = ' (' + speaker + '-' + start + '-' + end + ')\n'
                f_ref.write(text_ref + trn_id)
                f_hyp.write(text_hyp + trn_id)

                # Compute WER
                ref_words = text_ref.split(' ')
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=ref_words,
                    hyp=text_hyp.split(' '),
                    normalize=False)
                wer += wer_b
                num_sub += sub_b
                num_ins += ins_b
                num_del += del_b
                num_words += len(ref_words)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Guard against an evaluation set that produced no reference words
    if num_words > 0:
        wer /= num_words
        num_sub /= num_words
        num_ins /= num_words
        num_del /= num_words

    return wer, num_sub, num_ins, num_del, os.path.join(
        model.save_path, decode_dir)
def eval_word(models, dataloader, recog_params, epoch,
              recog_dir=None, streaming=False, progressbar=False,
              edit_distance=True, fine_grained=False, oracle=False,
              teacher_force=False):
    """Evaluate a word-level model by WER.

    Args:
        models (List): models to evaluate. `models[0]` decodes; the remaining
            ones are passed to it as `ensemble_models`
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch (used only to name the default `recog_dir`)
        recog_dir (str): directory path to save hypotheses. When None, a
            directory named after the decoding hyperparameters is created
            under `models[0].save_path`
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation)
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
        oracle (bool): calculate oracle WER
        teacher_force (bool): conduct decoding in teacher-forcing mode
            (not referenced inside this function body)
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate
        n_oov_total (int): total number of OOV

    Note:
        `wer` and `cer` are normalized by the reference lengths only when
        `edit_distance` is True and `streaming` is False; otherwise the
        accumulated values are returned as they are.

    """
    if recog_dir is None:
        recog_dir = 'decode_' + dataloader.set + '_ep' + \
            str(epoch) + '_beam' + str(recog_params.get('recog_beam_width'))
        recog_dir += '_lp' + str(recog_params.get('recog_length_penalty'))
        recog_dir += '_cp' + str(recog_params.get('recog_coverage_penalty'))
        recog_dir += '_' + str(recog_params.get('recog_min_len_ratio')) + '_' + \
            str(recog_params.get('recog_max_len_ratio'))
        recog_dir += '_lm' + str(recog_params.get('recog_lm_weight'))

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    wer_dist = {}  # calculate WER distribution based on input lengths
    n_oov_total = 0

    wer_oracle = 0
    n_oracle_hit = 0
    n_utt = 0

    # Reset data counter
    dataloader.reset(recog_params.get('recog_batch_size'))

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        for batch in dataloader:
            speakers = batch['sessions' if dataloader.corpus == 'swbd' else 'speakers']
            if streaming or recog_params.get('recog_block_sync'):
                nbest_hyps_id = models[0].decode_streaming(
                    batch['xs'], recog_params, dataloader.idx2token[0],
                    exclude_eos=True,
                    speaker=speakers[0])[0]
            else:
                nbest_hyps_id, aws = models[0].decode(
                    batch['xs'], recog_params,
                    idx2token=dataloader.idx2token[0],
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=speakers,
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                nbest_hyps = [dataloader.idx2token[0](hyp_id)
                              for hyp_id in nbest_hyps_id[b]]
                n_oov_total += nbest_hyps[0].count('<unk>')

                # Resolving UNK
                if recog_params.get('recog_resolving_unk') and '<unk>' in nbest_hyps[0]:
                    # Greedy, LM-free decoding with the character-level sub task
                    recog_params_char = copy.deepcopy(recog_params)
                    recog_params_char['recog_lm_weight'] = 0
                    recog_params_char['recog_beam_width'] = 1
                    # NOTE(review): `xs` is sliced to one utterance while
                    # refs_id / utt_ids / speakers cover the whole batch --
                    # confirm that `decode` tolerates this.
                    best_hyps_id_char, aw_char = models[0].decode(
                        batch['xs'][b:b + 1], recog_params_char,
                        idx2token=dataloader.idx2token[1],
                        exclude_eos=True,
                        refs_id=batch['ys_sub1'],
                        utt_ids=batch['utt_ids'],
                        speakers=speakers,
                        task='ys_sub1')
                    # TODO(hirofumi): support ys_sub2

                    # `aws` is only produced by the non-streaming branch above
                    assert not streaming

                    nbest_hyps[0] = resolve_unk(
                        nbest_hyps[0], best_hyps_id_char[0], aws[b], aw_char[0],
                        dataloader.idx2token[1],
                        subsample_factor_word=np.prod(models[0].subsample),
                        subsample_factor_char=np.prod(
                            models[0].subsample[:models[0].enc_n_layers_sub1 - 1]))
                    logger.debug('Hyp (after OOV resolution): %s' % nbest_hyps[0])
                    nbest_hyps[0] = nbest_hyps[0].replace('*', '')

                    # Compute CER
                    ref_char = ref
                    hyp_char = nbest_hyps[0]
                    if dataloader.corpus == 'csj':
                        ref_char = ref_char.replace(' ', '')
                        hyp_char = hyp_char.replace(' ', '')
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=list(ref_char), hyp=list(hyp_char))
                    cer += err_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref_char)

                # Write to trn
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id + ')\n')
                # `n_utt` only advances once per batch, so add the in-batch
                # offset to give every utterance its own running index.
                logger.debug('utt-id (%d/%d): %s' %
                             (n_utt + b + 1, len(dataloader), utt_id))
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % nbest_hyps[0])
                logger.debug('-' * 150)

                if edit_distance and not streaming:
                    # Compute WER
                    ref_words = ref.split(' ')
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref_words, hyp=nbest_hyps[0].split(' '))
                    wer += err_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref_words)

                    # Compute oracle WER
                    if oracle and len(nbest_hyps) > 1:
                        wers_b = [err_b] + [
                            compute_wer(ref=ref_words, hyp=hyp_n.split(' '))[0]
                            for hyp_n in nbest_hyps[1:]
                        ]
                        oracle_idx = np.argmin(np.array(wers_b))
                        if oracle_idx == 0:
                            # One hit per utterance (this runs inside the
                            # per-utterance loop, not once per batch)
                            n_oracle_hit += 1
                        wer_oracle += wers_b[oracle_idx]
                        # NOTE: OOV resolution is not considered

                    if fine_grained:
                        # Bucket by input length in steps of 200
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        wer_dist.setdefault(xlen_bin, []).append(err_b / 100)

            n_utt += len(batch['utt_ids'])
            if progressbar:
                pbar.update(len(batch['utt_ids']))

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    if edit_distance and not streaming:
        # Guard against an empty evaluation set
        if n_word > 0:
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word

        if n_char > 0:
            cer /= n_char
            n_sub_c /= n_char
            n_ins_c /= n_char
            n_del_c /= n_char

        if recog_params.get('recog_beam_width') > 1:
            logger.info('WER (%s): %.2f %%' % (dataloader.set, wer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_w, n_ins_w, n_del_w))
            logger.info('CER (%s): %.2f %%' % (dataloader.set, cer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_c, n_ins_c, n_del_c))
            logger.info('OOV (total): %d' % (n_oov_total))

        if oracle:
            if n_word > 0:
                wer_oracle /= n_word
            oracle_hit_rate = n_oracle_hit * 100 / n_utt if n_utt > 0 else 0
            logger.info('Oracle WER (%s): %.2f %%' % (dataloader.set, wer_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' %
                        (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, wers in sorted(wer_dist.items(), key=lambda x: x[0]):
                logger.info(' WER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(wers) / len(wers), len_bin))

    return wer, cer, n_oov_total
def eval_word(models, dataset, recog_params, epoch,
              recog_dir=None, streaming=False, progressbar=False):
    """Run one evaluation pass of a word-level model and report WER / CER.

    References and hypotheses are written to `ref.trn` / `hyp.trn`, and the
    final figures are emitted through `logger.debug`.

    Args:
        models (list): models to evaluate; `models[0]` decodes and the rest
            are handed to it as `ensemble_models`
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters (`recog_*` keys)
        epoch (int): only used to name the default output directory
        recog_dir (str): output directory; when None, one is derived from
            the decoding hyperparameters under `models[0].save_path`
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate
        n_oov_total (int): total number of OOV

    """
    # Rewind the dataset before starting
    dataset.reset(recog_params['recog_batch_size'])

    if recog_dir is None:
        recog_dir = ''.join([
            'decode_', dataset.set,
            '_ep', str(epoch),
            '_beam', str(recog_params['recog_beam_width']),
            '_lp', str(recog_params['recog_length_penalty']),
            '_cp', str(recog_params['recog_coverage_penalty']),
            '_', str(recog_params['recog_min_len_ratio']),
            '_', str(recog_params['recog_max_len_ratio']),
            '_lm', str(recog_params['recog_lm_weight']),
        ])
        ref_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_path = mkdir_join(recog_dir, 'hyp.trn')

    # Word-level accumulators
    wer = 0
    n_sub_w = n_ins_w = n_del_w = 0
    n_word = 0
    # Character-level accumulators
    cer = 0
    n_sub_c = n_ins_c = n_del_c = 0
    n_char = 0
    n_oov_total = 0

    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_path, 'w') as f_hyp, open(ref_path, 'w') as f_ref:
        is_new_epoch = False
        while not is_new_epoch:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])

            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataset.idx2token[0],
                    exclude_eos=True)
            else:
                word_idx2token = dataset.idx2token[0] if progressbar else None
                speaker_key = 'sessions' if dataset.corpus == 'swbd' else 'speakers'
                best_hyps_id, aws = models[0].decode(
                    batch['xs'], recog_params,
                    idx2token=word_idx2token,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch[speaker_key],
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataset.idx2token[0](best_hyps_id[b])
                n_oov_total += hyp.count('<unk>')

                # Replace <unk> using the character-level sub task
                if recog_params['recog_resolving_unk'] and '<unk>' in hyp:
                    char_params = copy.deepcopy(recog_params)
                    char_params['recog_lm_weight'] = 0
                    char_params['recog_beam_width'] = 1

                    char_idx2token = dataset.idx2token[1] if progressbar else None
                    if dataset.corpus == 'swbd':
                        char_speakers = batch['sessions']
                    else:
                        char_speakers = batch['speakers']
                    best_hyps_id_char, aw_char = models[0].decode(
                        batch['xs'][b:b + 1], char_params,
                        idx2token=char_idx2token,
                        exclude_eos=True,
                        refs_id=batch['ys_sub1'],
                        utt_ids=batch['utt_ids'],
                        speakers=char_speakers,
                        task='ys_sub1')
                    # TODO(hirofumi): support ys_sub2 and ys_sub3

                    assert not streaming

                    factor_word = np.prod(models[0].subsample)
                    factor_char = np.prod(
                        models[0].subsample[:models[0].enc_n_layers_sub1 - 1])
                    hyp = resolve_unk(
                        hyp, best_hyps_id_char[0], aws[b], aw_char[0],
                        dataset.idx2token[1],
                        subsample_factor_word=factor_word,
                        subsample_factor_char=factor_char)
                    logger.debug('Hyp (after OOV resolution): %s' % hyp)
                    hyp = hyp.replace('*', '')

                    # Character error statistics (spaces dropped for csj)
                    is_csj = dataset.corpus == 'csj'
                    ref_char = ref.replace(' ', '') if is_csj else ref
                    hyp_char = hyp.replace(' ', '') if is_csj else hyp
                    cer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=list(ref_char),
                        hyp=list(hyp_char),
                        normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref_char)

                # Emit one trn line per file
                speaker = str(batch['speakers'][b]).replace('-', '_')
                utt_id = str(batch['utt_ids'][b])
                if streaming:
                    utt_id = utt_id + '_0000000_0000001'
                trn_tag = ' (' + speaker + '-' + utt_id + ')\n'
                f_ref.write(ref + trn_tag)
                f_hyp.write(hyp + trn_tag)
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # Word error statistics
                    ref_words = ref.split(' ')
                    wer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref_words,
                        hyp=hyp.split(' '),
                        normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref_words)

                if progressbar:
                    pbar.update(1)

    if progressbar:
        pbar.close()

    # Rewind the dataset for the next user
    dataset.reset()

    if not streaming:
        wer = wer / n_word
        n_sub_w = n_sub_w / n_word
        n_ins_w = n_ins_w / n_word
        n_del_w = n_del_w / n_word

        if n_char > 0:
            cer = cer / n_char
            n_sub_c = n_sub_c / n_char
            n_ins_c = n_ins_c / n_char
            n_del_c = n_del_c / n_char

    error_fmt = 'SUB: %.2f / INS: %.2f / DEL: %.2f'
    logger.debug('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.debug(error_fmt % (n_sub_w, n_ins_w, n_del_w))
    logger.debug('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.debug(error_fmt % (n_sub_c, n_ins_c, n_del_c))
    logger.debug('OOV (total): %d' % (n_oov_total))

    return wer, cer, n_oov_total
def eval_word(models, dataset, decode_params, epoch, decode_dir=None, progressbar=False):
    """Evaluate the word-level model by WER.

    Decodes `dataset` batch by batch with `models[0]`, writes the references
    and hypotheses to `ref.trn` / `hyp.trn`, logs every utterance with
    `logger.info`, and accumulates the statistics returned by `compute_wer`.

    Args:
        models (list): the models to evaluate (only `models[0]` is used)
        dataset: An instance of a `Dataset' class
        decode_params (dict): decoding hyperparameters. Keys read here:
            `beam_width`, `length_penalty`, `coverage_penalty`,
            `min_len_ratio`, `max_len_ratio`, `rnnlm_weight`, `batch_size`
            and `resolving_unk`
        epoch (int): used only to name the default decoding directory
        decode_dir (str): directory for the trn files. When None, a directory
            named after the decoding hyperparameters is created under
            `model.save_path`
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): accumulated word errors divided by the number of
            reference words
        nsub (float): substitution errors divided by the number of
            reference words
        nins (float): insertion errors divided by the number of
            reference words
        ndel (float): deletion errors divided by the number of
            reference words
        noov_total (int): number of '<unk>' occurrences in the hypotheses
            (counted before UNK resolution)

    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO(hirofumi): ensemble decoding

    if decode_dir is None:
        decode_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(decode_params['beam_width'])
        decode_dir += '_lp' + str(decode_params['length_penalty'])
        decode_dir += '_cp' + str(decode_params['coverage_penalty'])
        decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(
            decode_params['max_len_ratio'])
        decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

        ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(decode_dir, 'hyp.trn')

    wer = 0
    nsub, nins, ndel = 0, 0, 0
    nword = 0
    noov_total = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO(hirofumi): fix this

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, aws, perm_ids = model.decode(batch['xs'], decode_params,
                                                    exclude_eos=True)
            # Hypotheses come back in `perm_ids` order, so every per-utterance
            # field taken from the batch has to be re-ordered the same way.
            ys = [batch['text'][i] for i in perm_ids]
            utt_ids = [batch['utt_ids'][i] for i in perm_ids]

            for b in six.moves.range(len(batch['xs'])):
                ref = ys[b]
                hyp = dataset.id2word(best_hyps[b])
                noov_total += hyp.count('<unk>')

                # Resolving UNK
                if decode_params['resolving_unk'] and '<unk>' in hyp:
                    # Decode the same utterance the b-th hypothesis belongs to
                    src = int(perm_ids[b])
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'][src:src + 1], decode_params,
                        exclude_eos=True)
                    # task_index=1

                    hyp = resolve_unk(
                        hyp, best_hyps_sub[0], aws[b], aw_sub[0], dataset.id2char,
                        diff_time_resolution=2**sum(model.subsample) //
                        2**sum(model.subsample[:model.enc_nlayers_sub - 1]))
                    hyp = hyp.replace('*', '')

                # Write to trn
                # The last two '_'-separated fields of the utterance id are
                # used as start / end; everything before them as the speaker.
                utt_id_fields = utt_ids[b].replace('-', '_').split('_')
                speaker = '_'.join(utt_id_fields[:-2])
                start = utt_id_fields[-2]
                end = utt_id_fields[-1]
                trn_id = ' (' + speaker + '-' + start + '-' + end + ')\n'
                f_ref.write(ref + trn_id)
                f_hyp.write(hyp + trn_id)
                logger.info('utt-id: %s' % utt_ids[b])
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

                # Compute WER
                ref_words = ref.split(' ')
                wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref_words,
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                wer += wer_b
                nsub += sub_b
                nins += ins_b
                ndel += del_b
                nword += len(ref_words)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Guard against an evaluation set that produced no reference words
    if nword > 0:
        wer /= nword
        nsub /= nword
        nins /= nword
        ndel /= nword

    return wer, nsub, nins, ndel, noov_total