def eval_word(models, dataset, decode_params, epoch, progressbar=False):
    """Evaluate the word-level model by WER.

    Args:
        models (list): the models to evaluate (only models[0] is used for now)
        dataset: An instance of a `Dataset' class
        decode_params (dict): decoding hyperparameters (beam width, penalties, LM weight, ...)
        epoch (int): epoch number, used to name the decoding directory
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        num_sub (float): substitution rate per reference word
        num_ins (float): insertion rate per reference word
        num_del (float): deletion rate per reference word
        decode_dir (str): path (relative to model.save_path join) of the decoding directory

    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO(hirofumi): ensemble decoding

    # Encode the decoding configuration into the output directory name so runs
    # with different hyperparameters do not overwrite each other.
    decode_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(
        decode_params['beam_width'])
    decode_dir += '_lp' + str(decode_params['length_penalty'])
    decode_dir += '_cp' + str(decode_params['coverage_penalty'])
    decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(
        decode_params['max_len_ratio'])
    decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

    ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
    hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')

    wer = 0
    num_sub, num_ins, num_del = 0, 0, 0  # (removed stray trailing comma)
    num_words = 0
    num_oov_total = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO(hirofumi): fix this

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, aw, perm_idx = model.decode(batch['xs'], decode_params,
                                                   exclude_eos=True)
            # decode() permutes the batch; re-align the references
            ys = [batch['ys'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                # Reference: raw text on the test sets, index sequences otherwise
                if dataset.is_test:
                    text_ref = ys[b]
                else:
                    text_ref = dataset.idx2word(ys[b])

                # Hypothesis
                text_hyp = dataset.idx2word(best_hyps[b])
                num_oov_total += text_hyp.count('<unk>')

                # Resolving UNK: re-decode the utterance with the character-level
                # sub task and substitute each <unk> with the characters aligned
                # to the same time frames.
                if decode_params['resolving_unk'] and '<unk>' in text_hyp:
                    # BUGFIX: `batch['xs']` was passed a second time positionally,
                    # shifting decode_params into the wrong parameter slot
                    # (compare the main decode call above). Decode the sub task
                    # instead, as the original trailing comment indicated.
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'][b:b + 1], decode_params, exclude_eos=True,
                        task_index=1)
                    text_hyp = resolve_unk(
                        text_hyp, best_hyps_sub[0], aw[b], aw_sub[0], dataset.idx2char,
                        diff_time_resolution=2 ** sum(model.subsample_list) // 2 ** sum(
                            model.subsample_list[:model.encoder_num_layers_sub - 1]))
                    # resolve_unk marks substituted regions with '*'; strip them
                    text_hyp = text_hyp.replace('*', '')

                # Write to trn (sclite format: "<words> (<speaker>-<start>-<end>)")
                speaker = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(text_ref + ' (' + speaker + '-' + start + '-' + end + ')\n')
                f_hyp.write(text_hyp + ' (' + speaker + '-' + start + '-' + end + ')\n')

                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(ref=text_ref.split(' '),
                                                         hyp=text_hyp.split(' '),
                                                         normalize=False)
                wer += wer_b
                num_sub += sub_b
                num_ins += ins_b
                num_del += del_b
                num_words += len(text_ref.split(' '))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Normalize by the total number of reference words
    wer /= num_words
    num_sub /= num_words
    num_ins /= num_words
    num_del /= num_words

    return wer, num_sub, num_ins, num_del, os.path.join(model.save_path, decode_dir)
def eval_char(models, dataloader, recog_params, epoch, recog_dir=None,
              streaming=False, progressbar=False, task_idx=0):
    """Evaluate the character-level model by WER & CER.

    Writes reference/hypothesis transcripts in sclite trn format to
    ``ref.trn``/``hyp.trn`` under the decoding directory, then accumulates
    word- and character-level edit-distance statistics.

    Args:
        models (list): models to evaluate (models[0] decodes; the rest are ensembled)
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (dict): decoding hyperparameters
        epoch (int): epoch number, used to name the decoding directory
        recog_dir (str): optional explicit output directory; if None, a name is
            derived from the decoding hyperparameters
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
        task_idx (int): the index of the target task in interest
            0: main task
            1: sub task
            2: sub sub task
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    if recog_dir is None:
        # Encode the decoding configuration into the directory name so runs with
        # different hyperparameters do not overwrite each other.
        recog_dir = 'decode_' + dataloader.set + '_ep' + str(epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0  # word-level error counts
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0  # character-level error counts
    n_word, n_char = 0, 0
    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0

    # Reset data counter
    dataloader.reset(recog_params['recog_batch_size'])

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    # Map the task index to the decoder's task name
    if task_idx == 0:
        task = 'ys'
    elif task_idx == 1:
        task = 'ys_sub1'
    elif task_idx == 2:
        task = 'ys_sub2'
    elif task_idx == 3:
        task = 'ys_sub3'

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataloader.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataloader.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, _ = models[0].decode(
                    batch['xs'], recog_params,
                    # idx2token is only needed for logging, so skip it otherwise
                    idx2token=dataloader.idx2token[task_idx] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'] if task_idx == 0 else batch['ys_sub' + str(task_idx)],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataloader.corpus == 'swbd' else 'speakers'],
                    task=task,
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataloader.idx2token[task_idx](best_hyps_id[b])

                # Truncate the first and last spaces for the char_space unit
                if len(hyp) > 0 and hyp[0] == ' ':
                    hyp = hyp[1:]
                if len(hyp) > 0 and hyp[-1] == ' ':
                    hyp = hyp[:-1]

                # Write to trn (sclite format: "<chars> (<speaker>-<utt_id>)")
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    # dummy start/end markers for session-level decoding
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # WER is only meaningful when the unit includes word boundaries
                    if ('char' in dataloader.unit and 'nowb' not in dataloader.unit) or (task_idx > 0 and dataloader.unit_sub1 == 'char'):
                        # Compute WER
                        wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                                 hyp=hyp.split(' '),
                                                                 normalize=False)
                        wer += wer_b
                        n_sub_w += sub_b
                        n_ins_w += ins_b
                        n_del_w += del_b
                        n_word += len(ref.split(' '))
                        # NOTE: sentence error rate for Chinese

                    # Compute CER
                    if dataloader.corpus == 'csj':
                        # CSJ references contain spaces that should not count as tokens
                        ref = ref.replace(' ', '')
                        hyp = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                             hyp=list(hyp),
                                                             normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)

                    # Streamability statistics of the last decoded utterance
                    # NOTE(review): placement inside `if not streaming` assumed from
                    # the offline-only semantics of these stats — confirm upstream.
                    if models[0].streamable():
                        n_streamable += 1
                    else:
                        last_success_frame_ratio += models[0].last_success_frame_ratio()
                    quantity_rate += models[0].quantity_rate()

                n_utt += 1
                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset()

    if not streaming:
        if ('char' in dataloader.unit and 'nowb' not in dataloader.unit) or (task_idx > 0 and dataloader.unit_sub1 == 'char'):
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word
        else:
            wer = n_sub_w = n_ins_w = n_del_w = 0

        # NOTE(review): unguarded division — raises ZeroDivisionError on an
        # empty evaluation set (n_char == 0)
        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

    # Average the streamability stats over utterances
    if n_utt - n_streamable > 0:
        last_success_frame_ratio /= (n_utt - n_streamable)
    n_streamable /= n_utt
    quantity_rate /= n_utt

    logger.debug('WER (%s): %.2f %%' % (dataloader.set, wer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
    logger.debug('CER (%s): %.2f %%' % (dataloader.set, cer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))
    logger.info('Streamability (%s): %.2f %%' % (dataloader.set, n_streamable * 100))
    logger.info('Quantity rate (%s): %.2f %%' % (dataloader.set, quantity_rate * 100))
    logger.info('Last success frame ratio (%s): %.2f %%' % (dataloader.set, last_success_frame_ratio))

    return wer, cer
def eval_word(models, dataset, recog_params, epoch, recog_dir=None,
              streaming=False, progressbar=False):
    """Evaluate the word-level model by WER.

    Args:
        models (list): models to evaluate (models[0] decodes; the rest are ensembled)
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters
        epoch (int): epoch number, used to name the decoding directory
        recog_dir (str): optional explicit output directory; if None, a name is
            derived from the decoding hyperparameters
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate (accumulated only for utterances that
            went through OOV resolution; 0 otherwise)
        n_oov_total (int): total number of OOV

    """
    # Reset data counter
    dataset.reset(recog_params['recog_batch_size'])

    if recog_dir is None:
        # Encode the decoding configuration into the directory name
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(
            recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0  # word-level error counts
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0  # character-level error counts
    n_word, n_char = 0, 0
    n_oov_total = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataset.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, aws = models[0].decode(
                    batch['xs'], recog_params,
                    # idx2token is only needed for logging, so skip it otherwise
                    idx2token=dataset.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataset.corpus == 'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataset.idx2token[0](best_hyps_id[b])
                n_oov_total += hyp.count('<unk>')

                # Resolving UNK: greedily re-decode the utterance with the
                # character-level sub task and substitute each <unk> with the
                # characters aligned to the same time frames.
                if recog_params['recog_resolving_unk'] and '<unk>' in hyp:
                    recog_params_char = copy.deepcopy(recog_params)
                    recog_params_char['recog_lm_weight'] = 0
                    recog_params_char['recog_beam_width'] = 1
                    best_hyps_id_char, aw_char = models[0].decode(
                        batch['xs'][b:b + 1], recog_params_char,
                        idx2token=dataset.idx2token[1] if progressbar else None,
                        exclude_eos=True,
                        refs_id=batch['ys_sub1'],
                        utt_ids=batch['utt_ids'],
                        speakers=batch['sessions'] if dataset.corpus == 'swbd' else batch['speakers'],
                        task='ys_sub1')
                    # TODO(hirofumi): support ys_sub2 and ys_sub3
                    assert not streaming

                    hyp = resolve_unk(
                        hyp, best_hyps_id_char[0], aws[b], aw_char[0], dataset.idx2token[1],
                        subsample_factor_word=np.prod(models[0].subsample),
                        subsample_factor_char=np.prod(models[0].subsample[:models[0].enc_n_layers_sub1 - 1]))
                    logger.debug('Hyp (after OOV resolution): %s' % hyp)
                    # resolve_unk marks substituted regions with '*'; strip them
                    hyp = hyp.replace('*', '')

                    # Compute CER (only for utterances that went through OOV resolution)
                    ref_char = ref
                    hyp_char = hyp
                    if dataset.corpus == 'csj':
                        ref_char = ref.replace(' ', '')
                        hyp_char = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref_char),
                                                             hyp=list(hyp_char),
                                                             normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref_char)

                # Write to trn (sclite format: "<words> (<speaker>-<utt_id>)")
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    # dummy start/end markers for session-level decoding
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                             hyp=hyp.split(' '),
                                                             normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if not streaming:
        # ROBUSTNESS: guard against an empty evaluation set, mirroring the
        # existing `if n_char > 0` guard below (previously raised
        # ZeroDivisionError when n_word == 0).
        if n_word > 0:
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word

        if n_char > 0:
            cer /= n_char
            n_sub_c /= n_char
            n_ins_c /= n_char
            n_del_c /= n_char

    logger.debug('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
    logger.debug('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))
    logger.debug('OOV (total): %d' % (n_oov_total))

    return wer, cer, n_oov_total
def eval_wordpiece(models, dataset, recog_params, epoch, recog_dir=None,
                   progressbar=False):
    """Evaluate the wordpiece-level model by WER.

    Writes reference/hypothesis transcripts in sclite trn format to
    ``ref.trn``/``hyp.trn`` under the decoding directory, then accumulates
    word- and character-level edit-distance statistics.

    Args:
        models (list): models to evaluate (models[0] decodes; the rest are ensembled)
        dataset: An instance of a `Dataset' class
        recog_params (recog_dict): decoding hyperparameters
        epoch (int): epoch number, used to name the decoding directory
        recog_dir (str): optional explicit output directory; if None, a name is
            derived from the decoding hyperparameters
        progressbar (bool): visualize the progressbar
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    # Reset data counter
    dataset.reset()

    if recog_dir is None:
        # Encode the decoding configuration into the directory name so runs with
        # different hyperparameters do not overwrite each other.
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(
            recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0  # word-level error counts
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0  # character-level error counts
    n_word, n_char = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            best_hyps_id, _, perm_id, _ = models[0].decode(
                batch['xs'], recog_params, dataset.idx2token[0],
                exclude_eos=True,
                refs_id=batch['ys'],
                utt_ids=batch['utt_ids'],
                speakers=batch['sessions'] if dataset.corpus == 'swbd' else batch['speakers'],
                ensemble_models=models[1:] if len(models) > 1 else [])
            # decode() permutes the batch; re-align the reference texts
            ys = [batch['text'][i] for i in perm_id]

            for b in range(len(batch['xs'])):
                ref = ys[b]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn (sclite format: "<words> (<speaker>-<utt_id>)")
                utt_id = str(batch['utt_ids'][b])
                speaker = str(batch['speakers'][b]).replace('-', '_')
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 150)

                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                wer += wer_b
                n_sub_w += sub_b
                n_ins_w += ins_b
                n_del_w += del_b
                n_word += len(ref.split(' '))

                # Compute CER
                if dataset.corpus == 'csj':
                    # CSJ references contain spaces that should not count as tokens
                    ref = ref.replace(' ', '')
                    hyp = hyp.replace(' ', '')
                cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                         hyp=list(hyp),
                                                         normalize=False)
                cer += cer_b
                n_sub_c += sub_b
                n_ins_c += ins_b
                n_del_c += del_b
                n_char += len(ref)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Normalize by the totals of reference words/characters
    # NOTE(review): unguarded divisions — raise ZeroDivisionError on an empty set
    wer /= n_word
    n_sub_w /= n_word
    n_ins_w /= n_word
    n_del_w /= n_word

    cer /= n_char
    n_sub_c /= n_char
    n_ins_c /= n_char
    n_del_c /= n_char

    logger.info('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
    logger.info('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))

    return wer, cer
def eval_phone(models, dataloader, recog_params, epoch, recog_dir=None,
               streaming=False, progressbar=False, fine_grained=False,
               oracle=False, teacher_force=False):
    """Evaluate a phone-level model by PER.

    Args:
        models (List): models to evaluate (models[0] decodes; the rest are ensembled)
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (dict): decoding hyperparameters
        epoch (int): epoch number, used to name the decoding directory
        recog_dir (str): optional explicit output directory; if None, a name is
            derived from the decoding hyperparameters
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
        oracle (bool): calculate oracle PER (best hypothesis in the n-best list)
        fine_grained (bool): calculate fine-grained PER distributions based on input lengths
        teacher_force (bool): conduct decoding in teacher-forcing mode
            (currently unused here)
    Returns:
        per (float): Phone error rate

    """
    if recog_dir is None:
        # Encode the decoding configuration into the directory name
        recog_dir = 'decode_' + dataloader.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(
            recog_params['recog_max_len_ratio'])

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    per = 0
    n_sub, n_ins, n_del = 0, 0, 0
    n_phone = 0
    per_dist = {}  # calculate PER distribution based on input lengths
    per_oracle = 0
    n_oracle_hit = 0
    n_utt = 0

    # Reset data counter
    dataloader.reset(recog_params['recog_batch_size'])

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataloader.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_block_sync']:
                nbest_hyps_id = models[0].decode_streaming(
                    batch['xs'], recog_params, dataloader.idx2token[0],
                    exclude_eos=True)[0]
            else:
                nbest_hyps_id = models[0].decode(
                    batch['xs'], recog_params,
                    # idx2token is only needed for logging, so skip it otherwise
                    idx2token=dataloader.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=batch['sessions' if dataloader.corpus == 'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])[0]

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                nbest_hyps = [dataloader.idx2token[0](hyp_id)
                              for hyp_id in nbest_hyps_id[b]]

                # Write to trn (sclite format: "<phones> (<speaker>-<utt_id>)")
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    # dummy start/end markers for session-level decoding
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % nbest_hyps[0])
                logger.debug('-' * 150)

                if not streaming:
                    # Compute PER
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '), hyp=nbest_hyps[0].split(' '))
                    per += err_b
                    n_sub += sub_b
                    n_ins += ins_b
                    n_del += del_b
                    n_phone += len(ref.split(' '))

                    # Compute oracle PER (best hypothesis in the n-best list)
                    if oracle and len(nbest_hyps) > 1:
                        pers_b = [err_b] + [compute_wer(ref=ref.split(' '),
                                                        hyp=hyp_n.split(' '))[0]
                                            for hyp_n in nbest_hyps[1:]]
                        oracle_idx = np.argmin(np.array(pers_b))
                        if oracle_idx == 0:
                            n_oracle_hit += 1
                        per_oracle += pers_b[oracle_idx]

                    # BUGFIX: per_dist was declared and reported below but never
                    # populated, so the fine_grained report printed nothing.
                    # Bin by input length, mirroring the word-level evaluator.
                    if fine_grained:
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        if xlen_bin in per_dist.keys():
                            per_dist[xlen_bin] += [err_b / 100]
                        else:
                            per_dist[xlen_bin] = [err_b / 100]

                n_utt += 1
                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset()

    if not streaming:
        # Normalize by the total number of reference phones
        per /= n_phone
        n_sub /= n_phone
        n_ins /= n_phone
        n_del /= n_phone

        if recog_params['recog_beam_width'] > 1:
            logger.info('PER (%s): %.2f %%' % (dataloader.set, per))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub, n_ins, n_del))

        if oracle:
            per_oracle /= n_phone
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle PER (%s): %.2f %%' % (dataloader.set, per_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' % (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, pers in sorted(per_dist.items(), key=lambda x: x[0]):
                logger.info(' PER (%s): %.2f %% (%d)' % (dataloader.set, sum(pers) / len(pers), len_bin))

    return per
def eval_word(models, dataloader, recog_params, epoch, recog_dir=None,
              streaming=False, progressbar=False, edit_distance=True,
              fine_grained=False, oracle=False, teacher_force=False):
    """Evaluate a word-level model by WER.

    Args:
        models (List): models to evaluate (models[0] decodes; the rest are ensembled)
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch
        recog_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation)
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
        oracle (bool): calculate oracle WER
        teacher_force (bool): conduct decoding in teacher-forcing mode
            (currently unused here)
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate (accumulated only for utterances that
            went through OOV resolution)
        n_oov_total (int): total number of OOV

    """
    if recog_dir is None:
        # Encode the decoding configuration into the directory name
        recog_dir = 'decode_' + dataloader.set + '_ep' + \
            str(epoch) + '_beam' + str(recog_params.get('recog_beam_width'))
        recog_dir += '_lp' + str(recog_params.get('recog_length_penalty'))
        recog_dir += '_cp' + str(recog_params.get('recog_coverage_penalty'))
        recog_dir += '_' + str(recog_params.get('recog_min_len_ratio')) + '_' + \
            str(recog_params.get('recog_max_len_ratio'))
        recog_dir += '_lm' + str(recog_params.get('recog_lm_weight'))

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0  # word-level error counts
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0  # character-level error counts
    n_word, n_char = 0, 0
    wer_dist = {}  # calculate WER distribution based on input lengths
    n_oov_total = 0
    wer_oracle = 0
    n_oracle_hit = 0
    n_utt = 0

    # Reset data counter
    dataloader.reset(recog_params.get('recog_batch_size'))

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        for batch in dataloader:
            speakers = batch['sessions' if dataloader.corpus == 'swbd' else 'speakers']
            if streaming or recog_params.get('recog_block_sync'):
                nbest_hyps_id = models[0].decode_streaming(
                    batch['xs'], recog_params, dataloader.idx2token[0],
                    exclude_eos=True, speaker=speakers[0])[0]
            else:
                nbest_hyps_id, aws = models[0].decode(
                    batch['xs'], recog_params,
                    idx2token=dataloader.idx2token[0],
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    speakers=speakers,
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                nbest_hyps = [dataloader.idx2token[0](hyp_id)
                              for hyp_id in nbest_hyps_id[b]]
                n_oov_total += nbest_hyps[0].count('<unk>')

                # Resolving UNK: greedily re-decode with the character-level sub
                # task and substitute each <unk> with the characters aligned to
                # the same time frames.
                if recog_params.get('recog_resolving_unk') and '<unk>' in nbest_hyps[0]:
                    recog_params_char = copy.deepcopy(recog_params)
                    recog_params_char['recog_lm_weight'] = 0
                    recog_params_char['recog_beam_width'] = 1
                    best_hyps_id_char, aw_char = models[0].decode(
                        batch['xs'][b:b + 1], recog_params_char,
                        idx2token=dataloader.idx2token[1],
                        exclude_eos=True,
                        refs_id=batch['ys_sub1'],
                        utt_ids=batch['utt_ids'],
                        speakers=speakers,
                        task='ys_sub1')
                    # TODO(hirofumi): support ys_sub2
                    assert not streaming

                    nbest_hyps[0] = resolve_unk(
                        nbest_hyps[0], best_hyps_id_char[0], aws[b], aw_char[0],
                        dataloader.idx2token[1],
                        subsample_factor_word=np.prod(models[0].subsample),
                        subsample_factor_char=np.prod(models[0].subsample[:models[0].enc_n_layers_sub1 - 1]))
                    logger.debug('Hyp (after OOV resolution): %s' % nbest_hyps[0])
                    # resolve_unk marks substituted regions with '*'; strip them
                    nbest_hyps[0] = nbest_hyps[0].replace('*', '')

                    # Compute CER (only for utterances that went through OOV resolution)
                    ref_char = ref
                    hyp_char = nbest_hyps[0]
                    if dataloader.corpus == 'csj':
                        ref_char = ref_char.replace(' ', '')
                        hyp_char = hyp_char.replace(' ', '')
                    err_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref_char),
                                                             hyp=list(hyp_char))
                    cer += err_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref_char)

                # Write to trn (sclite format: "<words> (<speaker>-<utt_id>)")
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    # dummy start/end markers for session-level decoding
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id (%d/%d): %s' % (n_utt + 1, len(dataloader), utt_id))
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % nbest_hyps[0])
                logger.debug('-' * 150)

                if edit_distance and not streaming:
                    # Compute WER
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '), hyp=nbest_hyps[0].split(' '))
                    wer += err_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                    # Compute oracle WER (best hypothesis in the n-best list)
                    if oracle and len(nbest_hyps) > 1:
                        wers_b = [err_b] + [compute_wer(ref=ref.split(' '),
                                                        hyp=hyp_n.split(' '))[0]
                                            for hyp_n in nbest_hyps[1:]]
                        oracle_idx = np.argmin(np.array(wers_b))
                        if oracle_idx == 0:
                            # BUGFIX: was `n_oracle_hit += len(batch['utt_ids'])`,
                            # which over-counts hits whenever a batch holds more
                            # than one utterance (this check runs once per
                            # utterance while n_utt counts utterances), letting
                            # the hit rate exceed 100%.
                            n_oracle_hit += 1
                        wer_oracle += wers_b[oracle_idx]
                        # NOTE: OOV resolution is not considered

                    if fine_grained:
                        # Bin utterances by input length (200-frame bins)
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        if xlen_bin in wer_dist.keys():
                            wer_dist[xlen_bin] += [err_b / 100]
                        else:
                            wer_dist[xlen_bin] = [err_b / 100]

            n_utt += len(batch['utt_ids'])
            if progressbar:
                pbar.update(len(batch['utt_ids']))

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    if edit_distance and not streaming:
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word

        if n_char > 0:
            cer /= n_char
            n_sub_c /= n_char
            n_ins_c /= n_char
            n_del_c /= n_char

        if recog_params.get('recog_beam_width') > 1:
            logger.info('WER (%s): %.2f %%' % (dataloader.set, wer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
            logger.info('CER (%s): %.2f %%' % (dataloader.set, cer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))
            logger.info('OOV (total): %d' % (n_oov_total))

        if oracle:
            wer_oracle /= n_word
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle WER (%s): %.2f %%' % (dataloader.set, wer_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' % (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, wers in sorted(wer_dist.items(), key=lambda x: x[0]):
                logger.info(' WER (%s): %.2f %% (%d)' % (dataloader.set, sum(wers) / len(wers), len_bin))

    return wer, cer, n_oov_total
def eval_phone(models, dataset, decode_params, epoch, progressbar=False):
    """Evaluate a phone-level model by PER.

    Args:
        models (list): the models to evaluate (only the first one is used)
        dataset: An instance of a `Dataset' class
        decode_params (dict): decoding hyperparameters
        epoch (int): epoch number, used to name the decoding directory
        progressbar (bool): if True, visualize the progressbar
    Returns:
        per (float): Phone error rate
        num_sub (float): substitution rate per reference phone
        num_ins (float): insertion rate per reference phone
        num_del (float): deletion rate per reference phone
        decode_dir (str): full path of the decoding directory

    """
    # Rewind the evaluation set before decoding
    dataset.reset()

    model = models[0]
    # TODO(hirofumi): ensemble decoding

    # Fold the decoding configuration into the output directory name
    decode_dir = 'decode_%s_ep%s_beam%s' % (dataset.set, epoch,
                                            decode_params['beam_width'])
    decode_dir += '_lp%s' % decode_params['length_penalty']
    decode_dir += '_cp%s' % decode_params['coverage_penalty']
    decode_dir += '_%s_%s' % (decode_params['min_len_ratio'],
                              decode_params['max_len_ratio'])
    decode_dir += '_rnnlm%s' % decode_params['rnnlm_weight']

    ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
    hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')

    per = 0
    num_sub = num_ins = num_del = 0
    num_phones = 0
    pbar = tqdm(total=len(dataset)) if progressbar else None

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        is_new_epoch = False
        while not is_new_epoch:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, _, perm_idx = model.decode(batch['xs'], decode_params,
                                                  exclude_eos=True)
            # decode() permutes the batch; re-align the references
            refs = [batch['ys'][i] for i in perm_idx]

            for i in range(len(batch['xs'])):
                # Reference: raw text on the test sets, index sequences otherwise
                if dataset.is_test:
                    ref = refs[i]
                else:
                    ref = dataset.idx2phone(refs[i])

                # Hypothesis
                hyp = dataset.idx2phone(best_hyps[i])

                # trn format: "<phones> (<speaker>-<start>-<end>)"
                parts = batch['utt_ids'][i].replace('-', '_').split('_')
                tag = '(' + '_'.join(parts[:-2]) + '-' + parts[-2] + '-' + parts[-1] + ')'
                f_ref.write(ref + ' ' + tag + '\n')
                f_hyp.write(hyp + ' ' + tag + '\n')

                # Accumulate edit-distance statistics for PER
                per_i, sub_i, ins_i, del_i = compute_wer(ref=ref.split(' '),
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                per += per_i
                num_sub += sub_i
                num_ins += ins_i
                num_del += del_i
                num_phones += len(ref.split(' '))

                if pbar is not None:
                    pbar.update(1)

    if pbar is not None:
        pbar.close()

    # Rewind again so the next evaluation starts from scratch
    dataset.reset()

    # Normalize by the total number of reference phones
    per /= num_phones
    num_sub /= num_phones
    num_ins /= num_phones
    num_del /= num_phones

    return per, num_sub, num_ins, num_del, os.path.join(model.save_path, decode_dir)
def eval_char(models, dataset, recog_params, epoch, recog_dir=None,
              progressbar=False, task_idx=0):
    """Evaluate the character-level model by WER & CER.

    Writes reference/hypothesis transcripts in sclite trn format to
    ``ref.trn``/``hyp.trn`` under the decoding directory, then accumulates
    word- and character-level edit-distance statistics.

    Args:
        models (list): models to evaluate (models[0] decodes; the rest are ensembled)
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters
        epoch (int): epoch number, used to name the decoding directory
        recog_dir (str): optional explicit output directory; if None, a name is
            derived from the decoding hyperparameters
        progressbar (bool): visualize the progressbar
        task_idx (int): the index of the target task in interest
            0: main task
            1: sub task
            2: sub sub task
    Returns:
        wer (float): Word error rate (0 when the unit has no word boundaries)
        cer (float): Character error rate

    """
    # Reset data counter
    dataset.reset()

    if recog_dir is None:
        # Encode the decoding configuration into the directory name
        recog_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(
            recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0  # word-level error counts
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0  # character-level error counts
    n_word, n_char = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    # Map the task index to the decoder's task name
    if task_idx == 0:
        task = 'ys'
    elif task_idx == 1:
        task = 'ys_sub1'
    elif task_idx == 2:
        task = 'ys_sub2'
    elif task_idx == 3:
        task = 'ys_sub3'

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            best_hyps_id, _, _ = models[0].decode(
                batch['xs'], recog_params, dataset.idx2token[task_idx],
                exclude_eos=True,
                refs_id=batch['ys'] if task_idx == 0 else batch['ys_sub' + str(task_idx)],
                utt_ids=batch['utt_ids'],
                speakers=batch['sessions'] if dataset.corpus == 'swbd' else batch['speakers'],
                task=task,
                ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = dataset.idx2token[task_idx](best_hyps_id[b])

                # Write to trn (sclite format: "<chars> (<speaker>-<utt_id>)")
                utt_id = str(batch['utt_ids'][b])
                speaker = str(batch['speakers'][b]).replace('-', '_')
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.info('utt-id: %s' % utt_id)
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 150)

                # WER is only meaningful when the unit includes word boundaries
                if ('char' in dataset.unit and 'nowb' not in dataset.unit) or (
                        task_idx > 0 and dataset.unit_sub1 == 'char'):
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                             hyp=hyp.split(' '),
                                                             normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                # Compute CER
                if dataset.corpus == 'csj':
                    # CSJ references contain spaces that should not count as tokens
                    ref = ref.replace(' ', '')
                    hyp = hyp.replace(' ', '')
                cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                         hyp=list(hyp),
                                                         normalize=False)
                cer += cer_b
                n_sub_c += sub_b
                n_ins_c += ins_b
                n_del_c += del_b
                n_char += len(ref)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if ('char' in dataset.unit and 'nowb' not in dataset.unit) or (
            task_idx > 0 and dataset.unit_sub1 == 'char'):
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word
    else:
        wer = n_sub_w = n_ins_w = n_del_w = 0

    # NOTE(review): unguarded division — raises ZeroDivisionError on an
    # empty evaluation set (n_char == 0)
    cer /= n_char
    n_sub_c /= n_char
    n_ins_c /= n_char
    n_del_c /= n_char

    logger.info('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
    logger.info('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))

    return wer, cer
def eval_phone(models, dataset, recog_params, epoch, recog_dir=None, progressbar=False): """Evaluate a phone-level model by PER. Args: models (list): models to evaluate dataset (Dataset): evaluation dataset recog_params (dict): epoch (int): recog_dir (str): progressbar (bool): visualize the progressbar Returns: per (float): Phone error rate """ # Reset data counter dataset.reset() if recog_dir is None: recog_dir = 'decode_' + dataset.set + '_ep' + str( epoch) + '_beam' + str(recog_params['recog_beam_width']) recog_dir += '_lp' + str(recog_params['recog_length_penalty']) recog_dir += '_cp' + str(recog_params['recog_coverage_penalty']) recog_dir += '_' + str( recog_params['recog_min_len_ratio']) + '_' + str( recog_params['recog_max_len_ratio']) ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn') hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn') else: ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn') hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn') per = 0 n_sub, n_ins, n_del = 0, 0, 0 n_phone = 0 if progressbar: pbar = tqdm(total=len(dataset)) with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref: while True: batch, is_new_epoch = dataset.getitem( recog_params['recog_batch_size']) best_hyps_id, _, _ = models[0].decode( batch['xs'], recog_params, dataset.idx2token[0], exclude_eos=True, refs_id=batch['ys'], utt_ids=batch['utt_ids'], speakers=batch['sessions'] if dataset.corpus == 'swbd' else batch['speakers'], ensemble_models=models[1:] if len(models) > 1 else []) for b in range(len(batch['xs'])): ref = batch['text'][b] hyp = dataset.idx2token[0](best_hyps_id[b]) # Write to trn utt_id = str(batch['utt_ids'][b]) speaker = str(batch['speakers'][b]).replace('-', '_') f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n') f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n') logger.info('utt-id: %s' % batch['utt_ids'][b]) logger.info('Ref: %s' % ref) logger.info('Hyp: %s' % hyp) 
logger.info('-' * 150) # Compute PER per_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '), hyp=hyp.split(' '), normalize=False) per += per_b n_sub += sub_b n_ins += ins_b n_del += del_b n_phone += len(ref.split(' ')) if progressbar: pbar.update(1) if is_new_epoch: break if progressbar: pbar.close() # Reset data counters dataset.reset() per /= n_phone n_sub /= n_phone n_ins /= n_phone n_del /= n_phone logger.info('PER (%s): %.2f %%' % (dataset.set, per)) logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub, n_ins, n_del)) return per
def eval_char(models, dataloader, params, epoch=-1, rank=0, save_dir=None,
              streaming=False, progressbar=False, task_idx=0,
              edit_distance=True, fine_grained=False, oracle=False,
              teacher_force=False):
    """Evaluate a character-level model by WER & CER.

    Supports n-best decoding, oracle scoring over the n-best list,
    fine-grained CER binned by input length, streaming decoding, and
    distributed evaluation (only rank 0 writes the ``.trn`` files).

    Args:
        models (List): models to evaluate
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch
        rank (int): rank of current process group
        save_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation)
        task_idx (int): index of target task in interest
            0: main task
            1: sub task
            2: sub sub task
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
        oracle (bool): calculate oracle WER
        teacher_force (bool): conduct decoding in teacher-forcing mode
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    if save_dir is None:
        # Encode the decoding configuration into the directory name.
        save_dir = 'decode_' + dataloader.set + '_ep' + \
            str(epoch) + '_beam' + str(params.get('recog_beam_width'))
        save_dir += '_lp' + str(params.get('recog_length_penalty'))
        save_dir += '_cp' + str(params.get('recog_coverage_penalty'))
        save_dir += '_' + str(params.get('recog_min_len_ratio')) + '_' + \
            str(params.get('recog_max_len_ratio'))
        save_dir += '_lm' + str(params.get('recog_lm_weight'))

        ref_trn_path = mkdir_join(models[0].save_path, save_dir, 'ref.trn', rank=rank)
        hyp_trn_path = mkdir_join(models[0].save_path, save_dir, 'hyp.trn', rank=rank)
    else:
        ref_trn_path = mkdir_join(save_dir, 'ref.trn', rank=rank)
        hyp_trn_path = mkdir_join(save_dir, 'hyp.trn', rank=rank)

    # Accumulators: *_w are word-level, *_c are character-level totals.
    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    cer_dist = {}  # calculate CER distribution based on input lengths
    cer_oracle = 0
    n_oracle_hit = 0
    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0

    # Reset data counter
    dataloader.reset(params.get('recog_batch_size'), 'seq')

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    # Only rank 0 writes transcripts in the distributed setting.
    if rank == 0:
        f_hyp = codecs.open(hyp_trn_path, 'w', encoding='utf-8')
        f_ref = codecs.open(ref_trn_path, 'w', encoding='utf-8')

    # Map task index to the batch key used by the decoder.
    if task_idx == 0:
        task = 'ys'
    elif task_idx == 1:
        task = 'ys_sub1'
    elif task_idx == 2:
        task = 'ys_sub2'
    elif task_idx == 3:
        task = 'ys_sub3'

    for batch in dataloader:
        # swbd identifies speakers by session rather than speaker id
        speakers = batch['sessions' if dataloader.corpus == 'swbd' else 'speakers']
        if streaming or params.get('recog_block_sync'):
            nbest_hyps_id = models[0].decode_streaming(
                batch['xs'], params, dataloader.idx2token[0],
                exclude_eos=True,
                speaker=speakers[0])[0]
        else:
            nbest_hyps_id = models[0].decode(
                batch['xs'], params,
                idx2token=dataloader.idx2token[0],
                exclude_eos=True,
                refs_id=batch['ys'] if task_idx == 0 else batch['ys_sub' + str(task_idx)],
                utt_ids=batch['utt_ids'],
                speakers=speakers,
                task=task,
                ensemble_models=models[1:] if len(models) > 1 else [],
                teacher_force=teacher_force)[0]

        for b in range(len(batch['xs'])):
            ref = batch['text'][b]
            nbest_hyps_tmp = [dataloader.idx2token[0](hyp_id)
                              for hyp_id in nbest_hyps_id[b]]

            # Truncate the first and last spaces for the char_space unit
            nbest_hyps = []
            for hyp in nbest_hyps_tmp:
                if len(hyp) > 0 and hyp[0] == ' ':
                    hyp = hyp[1:]
                if len(hyp) > 0 and hyp[-1] == ' ':
                    hyp = hyp[:-1]
                nbest_hyps.append(hyp)

            # Write to trn (sclite-style "<text> (<speaker>-<utt_id>)" lines)
            speaker = str(batch['speakers'][b]).replace('-', '_')
            if streaming:
                # synthetic segment suffix for session-level streaming output
                utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
            else:
                utt_id = str(batch['utt_ids'][b])
            if rank == 0:
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id + ')\n')
            logger.debug('utt-id (%d/%d): %s' % (n_utt + 1, len(dataloader), utt_id))
            logger.debug('Ref: %s' % ref)
            logger.debug('Hyp: %s' % nbest_hyps[0])
            logger.debug('-' * 150)

            if edit_distance and not streaming:
                # WER is only meaningful when the unit keeps word boundaries.
                if ('char' in dataloader.unit and 'nowb' not in dataloader.unit) or (
                        task_idx > 0 and dataloader.unit_sub1 == 'char'):
                    # Compute WER
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '),
                        hyp=nbest_hyps[0].split(' '))
                    wer += err_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))
                    # NOTE: sentence error rate for Chinese

                # Compute CER
                if dataloader.corpus == 'csj':
                    # CSJ CER is computed without spaces
                    ref = ref.replace(' ', '')
                    nbest_hyps[0] = nbest_hyps[0].replace(' ', '')
                err_b, sub_b, ins_b, del_b = compute_wer(
                    ref=list(ref),
                    hyp=list(nbest_hyps[0]))
                cer += err_b
                n_sub_c += sub_b
                n_ins_c += ins_b
                n_del_c += del_b
                n_char += len(ref)

                # Compute oracle CER (best CER achievable within the n-best list)
                if oracle and len(nbest_hyps) > 1:
                    cers_b = [err_b] + [compute_wer(ref=list(ref),
                                                    hyp=list(hyp_n))[0]
                                        for hyp_n in nbest_hyps[1:]]
                    oracle_idx = np.argmin(np.array(cers_b))
                    if oracle_idx == 0:
                        n_oracle_hit += len(batch['utt_ids'])
                    cer_oracle += cers_b[oracle_idx]

                if fine_grained:
                    # Bin utterances by input length (200-frame buckets).
                    xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                    if xlen_bin in cer_dist.keys():
                        cer_dist[xlen_bin] += [err_b / 100]
                    else:
                        cer_dist[xlen_bin] = [err_b / 100]

                if models[0].streamable():
                    n_streamable += len(batch['utt_ids'])
                else:
                    last_success_frame_ratio += models[0].last_success_frame_ratio()
                quantity_rate += models[0].quantity_rate()

        n_utt += len(batch['utt_ids'])
        if progressbar:
            pbar.update(len(batch['utt_ids']))

    if rank == 0:
        f_hyp.close()
        f_ref.close()
    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    if edit_distance and not streaming:
        # Normalize accumulated counts into rates.
        if ('char' in dataloader.unit and 'nowb' not in dataloader.unit) or (
                task_idx > 0 and dataloader.unit_sub1 == 'char'):
            wer /= n_word
            n_sub_w /= n_word
            n_ins_w /= n_word
            n_del_w /= n_word
        else:
            wer = n_sub_w = n_ins_w = n_del_w = 0

        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

        if n_utt - n_streamable > 0:
            last_success_frame_ratio /= (n_utt - n_streamable)
        n_streamable /= n_utt
        quantity_rate /= n_utt

        if params.get('recog_beam_width') > 1:
            logger.info('WER (%s): %.2f %%' % (dataloader.set, wer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_w, n_ins_w, n_del_w))
            logger.info('CER (%s): %.2f %%' % (dataloader.set, cer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_c, n_ins_c, n_del_c))

        if oracle:
            cer_oracle /= n_char
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle CER (%s): %.2f %%' % (dataloader.set, cer_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' % (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, cers in sorted(cer_dist.items(), key=lambda x: x[0]):
                logger.info('  CER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(cers) / len(cers), len_bin))

        logger.info('Streamability (%s): %.2f %%' % (dataloader.set, n_streamable * 100))
        logger.info('Quantity rate (%s): %.2f %%' % (dataloader.set, quantity_rate * 100))
        logger.info('Last success frame ratio (%s): %.2f %%' %
                    (dataloader.set, last_success_frame_ratio))

    return wer, cer
def eval_wordpiece(models, dataloader, recog_params, epoch, recog_dir=None,
                   streaming=False, progressbar=False, fine_grained=False,
                   oracle=False, teacher_force=False):
    """Evaluate a wordpiece-level model by WER.

    Supports n-best decoding, oracle scoring over the n-best list,
    fine-grained WER binned by input length, and streaming decoding.

    Args:
        models (List): models to evaluate
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        recog_params (dict): decoding hyperparameters
        epoch (int): current epoch
        recog_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        oracle (bool): calculate oracle WER
        fine_grained (bool): calculate fine-grained WER distributions based on input lengths
        teacher_force (bool): conduct decoding in teacher-forcing mode
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    if recog_dir is None:
        # Encode the decoding configuration into the directory name.
        recog_dir = 'decode_' + dataloader.set + '_ep' + str(
            epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(
            recog_params['recog_min_len_ratio']) + '_' + str(
                recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_path = mkdir_join(recog_dir, 'hyp.trn')

    # Accumulators: *_w are word-level, *_c are character-level totals.
    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    wer_dist = {}  # calculate WER distribution based on input lengths
    wer_oracle = 0
    n_oracle_hit = 0
    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0

    # Reset data counter
    dataloader.reset(recog_params['recog_batch_size'])

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    with codecs.open(hyp_trn_path, 'w', encoding='utf-8') as f_hyp, \
            codecs.open(ref_trn_path, 'w', encoding='utf-8') as f_ref:
        while True:
            batch, is_new_epoch = dataloader.next(
                recog_params['recog_batch_size'])
            if streaming or recog_params['recog_block_sync']:
                nbest_hyps_id = models[0].decode_streaming(
                    batch['xs'], recog_params, dataloader.idx2token[0],
                    exclude_eos=True)[0]
            else:
                nbest_hyps_id = models[0].decode(
                    batch['xs'], recog_params,
                    # NOTE(review): idx2token is only passed when
                    # progressbar is set — presumably it is used just for
                    # progress logging inside decode; confirm intent.
                    idx2token=dataloader.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    # swbd identifies speakers by session rather than speaker id
                    speakers=batch['sessions' if dataloader.corpus == 'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])[0]

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                # Strip a leading tag token such as "<...>" from the reference
                if ref[0] == '<':
                    ref = ref.split('>')[1]
                nbest_hyps = [dataloader.idx2token[0](hyp_id)
                              for hyp_id in nbest_hyps_id[b]]

                # Write to trn (sclite-style "<text> (<speaker>-<utt_id>)" lines)
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    # synthetic segment suffix for session-level streaming output
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % nbest_hyps[0])
                logger.debug('-' * 150)

                if not streaming:
                    # Compute WER
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '),
                        hyp=nbest_hyps[0].split(' '))
                    wer += err_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                    # Compute oracle WER (best WER achievable within the n-best list)
                    if oracle and len(nbest_hyps) > 1:
                        wers_b = [err_b] + [compute_wer(ref=ref.split(' '),
                                                        hyp=hyp_n.split(' '))[0]
                                            for hyp_n in nbest_hyps[1:]]
                        oracle_idx = np.argmin(np.array(wers_b))
                        if oracle_idx == 0:
                            n_oracle_hit += 1
                        wer_oracle += wers_b[oracle_idx]

                    if fine_grained:
                        # Bin utterances by input length (200-frame buckets).
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        if xlen_bin in wer_dist.keys():
                            wer_dist[xlen_bin] += [err_b / 100]
                        else:
                            wer_dist[xlen_bin] = [err_b / 100]

                    # Compute CER
                    if dataloader.corpus == 'csj':
                        # CSJ CER is computed without spaces
                        ref = ref.replace(' ', '')
                        nbest_hyps[0] = nbest_hyps[0].replace(' ', '')
                    err_b, sub_b, ins_b, del_b = compute_wer(
                        ref=list(ref),
                        hyp=list(nbest_hyps[0]))
                    cer += err_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)

                    if models[0].streamable():
                        n_streamable += 1
                    else:
                        last_success_frame_ratio += models[0].last_success_frame_ratio()
                    quantity_rate += models[0].quantity_rate()

                n_utt += 1
                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset()

    if not streaming:
        # Normalize accumulated counts into rates.
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word

        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

        if n_utt - n_streamable > 0:
            last_success_frame_ratio /= (n_utt - n_streamable)
        n_streamable /= n_utt
        quantity_rate /= n_utt

        if recog_params['recog_beam_width'] > 1:
            logger.info('WER (%s): %.2f %%' % (dataloader.set, wer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_w, n_ins_w, n_del_w))
            logger.info('CER (%s): %.2f %%' % (dataloader.set, cer))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' %
                        (n_sub_c, n_ins_c, n_del_c))

        if oracle:
            wer_oracle /= n_word
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle WER (%s): %.2f %%' % (dataloader.set, wer_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' % (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, wers in sorted(wer_dist.items(), key=lambda x: x[0]):
                logger.info('  WER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(wers) / len(wers), len_bin))

        logger.info('Streamability (%s): %.2f %%' % (dataloader.set, n_streamable * 100))
        logger.info('Quantity rate (%s): %.2f %%' % (dataloader.set, quantity_rate * 100))
        logger.info('Last success frame ratio (%s): %.2f %%' %
                    (dataloader.set, last_success_frame_ratio))

    return wer, cer
def eval_wordpiece(models, dataset, decode_params, epoch, decode_dir=None,
                   progressbar=False):
    """Evaluate the wordpiece-level model by WER.

    Decodes the evaluation set, writes reference/hypothesis ``.trn`` files,
    and accumulates word-level edit-distance counts.

    Args:
        models (list): the models to evaluate (only models[0] is used)
        dataset: An instance of a `Dataset' class
        decode_params (dict): decoding hyperparameters
        epoch (int): current epoch (used in the output directory name)
        decode_dir (str): output directory; built from decoding
            hyperparameters under model.save_path when None
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        nsub (int): the number of substitution errors
        nins (int): the number of insertion errors
        ndel (int): the number of deletion errors

    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO(hirofumi): ensemble decoding

    if decode_dir is None:
        # Encode the decoding configuration into the directory name.
        decode_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(decode_params['beam_width'])
        decode_dir += '_lp' + str(decode_params['length_penalty'])
        decode_dir += '_cp' + str(decode_params['coverage_penalty'])
        decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(
            decode_params['max_len_ratio'])
        decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

        ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(decode_dir, 'hyp.trn')

    # Accumulators for the word-level edit-distance totals.
    wer = 0
    nsub, nins, ndel = 0, 0, 0
    nword = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, _, perm_id = model.decode(batch['xs'], decode_params,
                                                 exclude_eos=True,
                                                 id2token=dataset.id2wp,
                                                 refs=batch['ys'])
            # decode() may permute the batch; realign references accordingly.
            ys = [batch['text'][i] for i in perm_id]

            for b in six.moves.range(len(batch['xs'])):
                ref = ys[b]
                hyp = dataset.id2wp(best_hyps[b])

                # Write to trn. The utterance id encodes
                # "<speaker>_<start>_<end>"; split it back apart.
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(ref + ' (' + speaker + '-' + start + '-' + end + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + start + '-' + end + ')\n')
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                         hyp=hyp.split(' '),
                                                         normalize=False)
                wer += wer_b
                nsub += sub_b
                nins += ins_b
                ndel += del_b
                nword += len(ref.split(' '))

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Normalize accumulated counts into rates.
    wer /= nword
    nsub /= nword
    nins /= nword
    ndel /= nword

    return wer, nsub, nins, ndel
def eval_char(models, dataset, decode_params, epoch, decode_dir=None,
              progressbar=False, task_id=0):
    """Evaluate the character-level model by WER & CER.

    Decodes the evaluation set, writes reference/hypothesis ``.trn`` files,
    and accumulates word- and character-level edit-distance counts.

    Args:
        models (list): the models to evaluate (only models[0] is used)
        dataset: An instance of a `Dataset' class
        decode_params (dict): decoding hyperparameters
        epoch (int): current epoch (used in the output directory name)
        decode_dir (str): output directory; built from decoding
            hyperparameters under model.save_path when None
        progressbar (bool): if True, visualize the progressbar
        task_id (int): the index of the target task in interest
            0: main task
            1: sub task
            2: sub sub task
    Returns:
        wer (float): Word error rate
        nsub_w (int): the number of substitution errors for WER
        nins_w (int): the number of insertion errors for WER
        ndel_w (int): the number of deletion errors for WER
        cer (float): Character error rate
        nsub_c (int): the number of substitution errors for CER
        nins_c (int): the number of insertion errors for CER
        ndel_c (int): the number of deletion errors for CER

    """
    # Reset data counter
    dataset.reset()

    model = models[0]

    if decode_dir is None:
        # Encode the decoding configuration into the directory name.
        decode_dir = 'decode_' + dataset.set + '_ep' + str(
            epoch) + '_beam' + str(decode_params['beam_width'])
        decode_dir += '_lp' + str(decode_params['length_penalty'])
        decode_dir += '_cp' + str(decode_params['coverage_penalty'])
        decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(
            decode_params['max_len_ratio'])
        decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

        ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(decode_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(decode_dir, 'hyp.trn')

    # Accumulators: *_w are word-level, *_c are character-level totals.
    wer, cer = 0, 0
    nsub_w, nins_w, ndel_w = 0, 0, 0
    nsub_c, nins_c, ndel_c = 0, 0, 0
    nword, nchar = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    # Map task index to the task name used by the decoder.
    # NOTE(review): unlike the newer eval_char, task_id == 3 is not handled
    # here and would leave `task` unbound — confirm callers never pass 3.
    if task_id == 0:
        task = 'ys'
    elif task_id == 1:
        task = 'ys_sub1'
    elif task_id == 2:
        task = 'ys_sub2'

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, _, perm_ids = model.decode(batch['xs'], decode_params,
                                                  exclude_eos=True,
                                                  task=task)
            # decode() may permute the batch; realign references accordingly.
            ys = [batch['text'][i] for i in perm_ids]

            for b in six.moves.range(len(batch['xs'])):
                ref = ys[b]
                hyp = dataset.id2char(best_hyps[b])

                # Write to trn. The utterance id encodes
                # "<speaker>_<start>_<end>"; split it back apart.
                speaker = '_'.join(batch['utt_ids'][b].replace(
                    '-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(ref + ' (' + speaker + '-' + start + '-' + end + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + start + '-' + end + ')\n')
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % ref)
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

                # WER is only meaningful when the unit keeps word boundaries.
                if ('char' in dataset.unit and 'nowb' not in dataset.unit) or (
                        task_id > 0 and dataset.unit_sub1 == 'char'):
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(
                        ref=ref.split(' '),
                        hyp=hyp.split(' '),
                        normalize=False)
                    wer += wer_b
                    nsub_w += sub_b
                    nins_w += ins_b
                    ndel_w += del_b
                    nword += len(ref.split(' '))

                # Compute CER
                cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                         hyp=list(hyp),
                                                         normalize=False)
                cer += cer_b
                nsub_c += sub_b
                nins_c += ins_b
                ndel_c += del_b
                nchar += len(ref)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Normalize accumulated counts into rates.
    if ('char' in dataset.unit and 'nowb' not in dataset.unit) or (
            task_id > 0 and dataset.unit_sub1 == 'char'):
        wer /= nword
        nsub_w /= nword
        nins_w /= nword
        ndel_w /= nword
    else:
        wer = nsub_w = nins_w = ndel_w = 0

    cer /= nchar
    nsub_c /= nchar
    nins_c /= nchar
    ndel_c /= nchar

    return (wer, nsub_w, nins_w, ndel_w), (cer, nsub_c, nins_c, ndel_c)
def eval_wordpiece(models, dataset, recog_params, epoch, recog_dir=None,
                   streaming=False, progressbar=False, fine_grained=False):
    """Evaluate the wordpiece-level model by WER.

    Decodes the evaluation set (optionally in streaming mode), writes
    reference/hypothesis ``.trn`` files, and accumulates word- and
    character-level edit-distance counts plus streaming statistics.

    Args:
        models (list): models to evaluate; models[0] decodes, models[1:] are
            passed as ensemble members
        dataset (Dataset): evaluation dataset
        recog_params (dict): decoding hyperparameters
        epoch (int): current epoch (used in the output directory name)
        recog_dir (str): output directory; built from decoding hyperparameters
            under models[0].save_path when None
        streaming (bool): streaming decoding for the session-level evaluation
        progressbar (bool): visualize the progressbar
        fine_grained (bool): calculate fine-grained WER distributions based on
            input lengths
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate

    """
    # Reset data counter
    dataset.reset(recog_params['recog_batch_size'])

    if recog_dir is None:
        # Encode the decoding configuration into the directory name.
        recog_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(recog_params['recog_beam_width'])
        recog_dir += '_lp' + str(recog_params['recog_length_penalty'])
        recog_dir += '_cp' + str(recog_params['recog_coverage_penalty'])
        recog_dir += '_' + str(recog_params['recog_min_len_ratio']) + '_' + str(recog_params['recog_max_len_ratio'])
        recog_dir += '_lm' + str(recog_params['recog_lm_weight'])

        ref_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(models[0].save_path, recog_dir, 'hyp.trn')
    else:
        ref_trn_save_path = mkdir_join(recog_dir, 'ref.trn')
        hyp_trn_save_path = mkdir_join(recog_dir, 'hyp.trn')

    # Accumulators: *_w are word-level, *_c are character-level totals.
    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    n_streamable, quantity_rate, n_utt = 0, 0, 0
    last_success_frame_ratio = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    # calculate WER distribution based on input lengths
    wer_dist = {}

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(recog_params['recog_batch_size'])
            if streaming or recog_params['recog_chunk_sync']:
                best_hyps_id, _ = models[0].decode_streaming(
                    batch['xs'], recog_params, dataset.idx2token[0],
                    exclude_eos=True)
            else:
                best_hyps_id, _ = models[0].decode(
                    batch['xs'], recog_params,
                    idx2token=dataset.idx2token[0] if progressbar else None,
                    exclude_eos=True,
                    refs_id=batch['ys'],
                    utt_ids=batch['utt_ids'],
                    # swbd identifies speakers by session rather than speaker id
                    speakers=batch['sessions' if dataset.corpus == 'swbd' else 'speakers'],
                    ensemble_models=models[1:] if len(models) > 1 else [])

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                # Strip a leading tag token such as "<...>" from the reference
                if ref[0] == '<':
                    ref = ref.split('>')[1]
                hyp = dataset.idx2token[0](best_hyps_id[b])

                # Write to trn (sclite-style "<text> (<speaker>-<utt_id>)" lines)
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if streaming:
                    # synthetic segment suffix for session-level streaming output
                    utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
                else:
                    utt_id = str(batch['utt_ids'][b])
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logger.debug('utt-id: %s' % utt_id)
                logger.debug('Ref: %s' % ref)
                logger.debug('Hyp: %s' % hyp)
                logger.debug('-' * 150)

                if not streaming:
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                             hyp=hyp.split(' '),
                                                             normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                    if fine_grained:
                        # Bin utterances by input length (200-frame buckets).
                        xlen_bin = (batch['xlens'][b] // 200 + 1) * 200
                        if xlen_bin in wer_dist.keys():
                            wer_dist[xlen_bin] += [wer_b / 100]
                        else:
                            wer_dist[xlen_bin] = [wer_b / 100]

                    # Compute CER
                    if dataset.corpus == 'csj':
                        # CSJ CER is computed without spaces
                        ref = ref.replace(' ', '')
                        hyp = hyp.replace(' ', '')
                    cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                             hyp=list(hyp),
                                                             normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)

                if models[0].streamable():
                    n_streamable += 1
                else:
                    last_success_frame_ratio += models[0].last_success_frame_ratio()
                quantity_rate += models[0].quantity_rate()
                n_utt += 1

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    if not streaming:
        # Normalize accumulated counts into rates.
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word

        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

    if n_utt - n_streamable > 0:
        last_success_frame_ratio /= (n_utt - n_streamable)
    n_streamable /= n_utt
    quantity_rate /= n_utt

    if fine_grained:
        for len_bin, wers in sorted(wer_dist.items(), key=lambda x: x[0]):
            logger.info('  WER (%s): %.2f %% (%d)' % (dataset.set, sum(wers) / len(wers), len_bin))

    logger.debug('WER (%s): %.2f %%' % (dataset.set, wer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
    logger.debug('CER (%s): %.2f %%' % (dataset.set, cer))
    logger.debug('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))
    # BUGFIX: log message typo "Streamablility" corrected to "Streamability"
    # for consistency with the other eval_* functions in this file.
    logger.info('Streamability (%s): %.2f %%' % (dataset.set, n_streamable * 100))
    logger.info('Quantity rate (%s): %.2f %%' % (dataset.set, quantity_rate * 100))
    logger.info('Last success frame ratio (%s): %.2f %%' % (dataset.set, last_success_frame_ratio))

    return wer, cer
def eval_char(models, dataset, decode_params, epoch, progressbar=False):
    """Evaluate the character-level model by WER & CER.

    Decodes the evaluation set for the main task only, writes
    reference/hypothesis ``.trn`` files, and accumulates word- and
    character-level edit-distance counts.

    Args:
        models (list): the models to evaluate (only models[0] is used)
        dataset: An instance of a `Dataset' class
        decode_params (dict): decoding hyperparameters
        epoch (int): current epoch (used in the output directory name)
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        num_sub_w (int): the number of substitution errors for WER
        num_ins_w (int): the number of insertion errors for WER
        num_del_w (int): the number of deletion errors for WER
        cer (float): Character error rate
        num_sub_c (int): the number of substitution errors for CER
        num_ins_c (int): the number of insertion errors for CER
        num_del_c (int): the number of deletion errors for CER
        decode_dir (str): path to the directory holding the trn files

    """
    # Reset data counter
    dataset.reset()

    model = models[0]

    # BUGFIX: `task_index` was referenced in the conditions below without ever
    # being defined, raising NameError whenever the label type was not
    # character-based. This evaluator only handles the main task (see the
    # "task_index = 0" note on the decode call), so bind it explicitly.
    task_index = 0

    # Encode the decoding configuration into the directory name.
    decode_dir = 'decode_' + dataset.set + '_ep' + str(epoch) + '_beam' + str(decode_params['beam_width'])
    decode_dir += '_lp' + str(decode_params['length_penalty'])
    decode_dir += '_cp' + str(decode_params['coverage_penalty'])
    decode_dir += '_' + str(decode_params['min_len_ratio']) + '_' + str(decode_params['max_len_ratio'])
    decode_dir += '_rnnlm' + str(decode_params['rnnlm_weight'])

    ref_trn_save_path = mkdir_join(model.save_path, decode_dir, 'ref.trn')
    hyp_trn_save_path = mkdir_join(model.save_path, decode_dir, 'hyp.trn')

    # Accumulators: *_w are word-level, *_c are character-level totals.
    wer, cer = 0, 0
    num_sub_w, num_ins_w, num_del_w = 0, 0, 0
    num_sub_c, num_ins_c, num_del_c = 0, 0, 0
    num_words, num_chars = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))

    with open(hyp_trn_save_path, 'w') as f_hyp, open(ref_trn_save_path, 'w') as f_ref:
        while True:
            batch, is_new_epoch = dataset.next(decode_params['batch_size'])
            best_hyps, aw, perm_idx = model.decode(batch['xs'], decode_params,
                                                   exclude_eos=True)  # task_index = 0
            # decode() may permute the batch; realign references accordingly.
            ys = [batch['ys'][i] for i in perm_idx]

            for b in range(len(batch['xs'])):
                # Reference
                if dataset.is_test:
                    text_ref = ys[b]
                else:
                    text_ref = dataset.idx2char(ys[b])

                # Hypothesis
                text_hyp = dataset.idx2char(best_hyps[b])

                # Write to trn. The utterance id encodes
                # "<speaker>_<start>_<end>"; split it back apart.
                speaker = '_'.join(batch['utt_ids'][b].replace('-', '_').split('_')[:-2])
                start = batch['utt_ids'][b].replace('-', '_').split('_')[-2]
                end = batch['utt_ids'][b].replace('-', '_').split('_')[-1]
                f_ref.write(text_ref + ' (' + speaker + '-' + start + '-' + end + ')\n')
                f_hyp.write(text_hyp + ' (' + speaker + '-' + start + '-' + end + ')\n')

                # WER is only meaningful when the unit keeps word boundaries.
                if ('character' in dataset.label_type and 'nowb' not in dataset.label_type) or (task_index > 0 and dataset.label_type_sub == 'character'):
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(ref=text_ref.split(' '),
                                                             hyp=text_hyp.split(' '),
                                                             normalize=False)
                    wer += wer_b
                    num_sub_w += sub_b
                    num_ins_w += ins_b
                    num_del_w += del_b
                    num_words += len(text_ref.split(' '))

                # Compute CER (always; spaces are removed first)
                cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(text_ref.replace(' ', '')),
                                                         hyp=list(text_hyp.replace(' ', '')),
                                                         normalize=False)
                cer += cer_b
                num_sub_c += sub_b
                num_ins_c += ins_b
                num_del_c += del_b
                num_chars += len(text_ref)

                if progressbar:
                    pbar.update(1)

            if is_new_epoch:
                break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Normalize accumulated counts into rates.
    if ('character' in dataset.label_type and 'nowb' not in dataset.label_type) or (task_index > 0 and dataset.label_type_sub == 'character'):
        wer /= num_words
        num_sub_w /= num_words
        num_ins_w /= num_words
        num_del_w /= num_words
    else:
        wer = num_sub_w = num_ins_w = num_del_w = 0

    cer /= num_chars
    num_sub_c /= num_chars
    num_ins_c /= num_chars
    num_del_c /= num_chars

    return (wer, num_sub_w, num_ins_w, num_del_w), \
        (cer, num_sub_c, num_ins_c, num_del_c), \
        os.path.join(model.save_path, decode_dir)
def eval(epoch):
    """Run one evaluation pass over the dev set; report loss, WER and CER.

    Closure: uses `args`, `devset` and `model` from the enclosing scope.

    Args:
        epoch (int): current epoch, used to name the output trn files
    Returns:
        loss (float): dev loss averaged over the dataset
        wer (float): word error rate (0 when devset.unit is not char-based)
        cer (float): character error rate
    """
    recog_dir = args.out
    ref_trn_save_path = recog_dir + '/ref_epoch_' + str(epoch) + '.trn'
    hyp_trn_save_path = recog_dir + '/hyp_epoch_' + str(epoch) + '.trn'

    wer, cer = 0, 0
    n_sub_w, n_ins_w, n_del_w = 0, 0, 0
    n_sub_c, n_ins_c, n_del_c = 0, 0, 0
    n_word, n_char = 0, 0
    pbar = tqdm(total=len(devset))
    losses = []
    is_new_epoch = 0
    step = 0

    model.eval()  # hoisted out of the loop: the mode switch is loop-invariant
    # FIX: f_hyp/f_ref were opened but never closed (leak + possibly lost
    # buffered writes on exception); torch.no_grad() avoids building an
    # autograd graph during evaluation (the old `volatile=True` intent).
    with open(hyp_trn_save_path, 'w') as f_hyp, \
            open(ref_trn_save_path, 'w') as f_ref, \
            torch.no_grad():
        while True:
            batch, is_new_epoch = devset.next()
            xs, ys, xlens = batch['xs'], batch['ys'], batch['xlens']
            xs = [np2tensor(x).float() for x in batch['xs']]
            xlen = torch.IntTensor([len(x) for x in batch['xs']])
            xs = pad_list(xs, 0.0).cuda()
            _ys = [np2tensor(np.fromiter(y, dtype=np.int64), -1) for y in ys]
            ys_out_pad = pad_list(_ys, 0).long().cuda()
            ylen = np2tensor(np.fromiter([y.size(0) for y in _ys], dtype=np.int32))

            loss = model(xs, ys_out_pad, xlen, ylen)
            loss = float(loss.data) * len(xlen)  # un-average: weight by batch size
            losses.append(loss)
            step += 1

            # //TODO vishay un-hardcode the batch size
            best_hyps_id, _ = model.greedy_decode(xs)

            for b in range(len(batch['xs'])):
                ref = batch['text'][b]
                hyp = devset.idx2token[0](best_hyps_id[b])
                hyp = removeDuplicates(hyp)

                # Write to trn
                utt_id = str(batch['utt_ids'][b])
                speaker = str(batch['speakers'][b]).replace('-', '_')
                if hyp is None:
                    hyp = "none"
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(hyp + ' (' + speaker + '-' + utt_id + ')\n')
                logging.info('utt-id: %s' % utt_id)
                logging.info('Ref: %s' % ref)
                logging.info('Hyp: %s' % hyp)
                logging.info('-' * 150)

                if 'char' in devset.unit:  # //TODO this is only for char unit
                    # Compute WER
                    wer_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                             hyp=hyp.split(' '),
                                                             normalize=False)
                    wer += wer_b
                    n_sub_w += sub_b
                    n_ins_w += ins_b
                    n_del_w += del_b
                    n_word += len(ref.split(' '))

                    # Compute CER
                    cer_b, sub_b, ins_b, del_b = compute_wer(ref=list(ref),
                                                             hyp=list(hyp),
                                                             normalize=False)
                    cer += cer_b
                    n_sub_c += sub_b
                    n_ins_c += ins_b
                    n_del_c += del_b
                    n_char += len(ref)

            pbar.update(len(batch['xs']))
            if is_new_epoch:
                break

    pbar.close()

    # Reset data counters
    devset.reset()

    if 'char' in devset.unit and n_word > 0:
        wer /= n_word
        n_sub_w /= n_word
        n_ins_w /= n_word
        n_del_w /= n_word
    else:
        wer = n_sub_w = n_ins_w = n_del_w = 0
    # FIX: this division was unconditional and crashed with ZeroDivisionError
    # for non-char units (n_char stays 0), unlike the guarded WER branch.
    if n_char > 0:
        cer /= n_char
        n_sub_c /= n_char
        n_ins_c /= n_char
        n_del_c /= n_char

    logging.info('WER (%s): %.2f %%' % (devset.set, wer))
    logging.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_w, n_ins_w, n_del_w))
    logging.info('CER (%s): %.2f %%' % (devset.set, cer))
    logging.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub_c, n_ins_c, n_del_c))

    # print(step, '/12k dev')
    return sum(losses) / len(devset), wer, cer
def eval_phone(models, dataloader, params, epoch=-1, rank=0, save_dir=None,
               streaming=False, progressbar=False, edit_distance=True,
               fine_grained=False, oracle=False, teacher_force=False):
    """Evaluate a phone-level model by PER.

    Args:
        models (List): models to evaluate (models[1:] are used as an ensemble)
        dataloader (torch.utils.data.DataLoader): evaluation dataloader
        params (omegaconf.dictconfig.DictConfig): decoding hyperparameters
        epoch (int): current epoch
        rank (int): rank of current process group (only rank 0 writes trn files)
        save_dir (str): directory path to save hypotheses
        streaming (bool): streaming decoding for session-level evaluation
        progressbar (bool): visualize progressbar
        edit_distance (bool): calculate edit-distance (can be skipped for RTF calculation)
        fine_grained (bool): calculate fine-grained PER distributions based on input lengths
        oracle (bool): calculate oracle PER
        teacher_force (bool): conduct decoding in teacher-forcing mode
    Returns:
        per (float): Phone error rate

    """
    if save_dir is None:
        save_dir = 'decode_' + dataloader.set + '_ep' + \
            str(epoch) + '_beam' + str(params.get('recog_beam_width'))
        save_dir += '_lp' + str(params.get('recog_length_penalty'))
        save_dir += '_cp' + str(params.get('recog_coverage_penalty'))
        save_dir += '_' + str(params.get('recog_min_len_ratio')) + '_' + \
            str(params.get('recog_max_len_ratio'))
        ref_trn_path = mkdir_join(models[0].save_path, save_dir, 'ref.trn', rank=rank)
        hyp_trn_path = mkdir_join(models[0].save_path, save_dir, 'hyp.trn', rank=rank)
    else:
        ref_trn_path = mkdir_join(save_dir, 'ref.trn', rank=rank)
        hyp_trn_path = mkdir_join(save_dir, 'hyp.trn', rank=rank)

    per = 0
    n_sub, n_ins, n_del = 0, 0, 0
    n_phone = 0
    # NOTE(review): per_dist is never populated in this version, so the
    # fine_grained report below is always empty — confirm whether the
    # input-length-bin accumulation was dropped by mistake.
    per_dist = {}  # calculate PER distribution based on input lengths
    per_oracle = 0
    n_oracle_hit = 0
    n_utt = 0

    # Reset data counter
    dataloader.reset(params.get('recog_batch_size'), 'seq')

    if progressbar:
        pbar = tqdm(total=len(dataloader))

    if rank == 0:
        f_hyp = codecs.open(hyp_trn_path, 'w', encoding='utf-8')
        f_ref = codecs.open(ref_trn_path, 'w', encoding='utf-8')

    for batch in dataloader:
        speakers = batch['sessions' if dataloader.corpus == 'swbd' else 'speakers']
        if streaming or params.get('recog_block_sync'):
            nbest_hyps_id = models[0].decode_streaming(
                batch['xs'], params, dataloader.idx2token[0],
                exclude_eos=True,
                speaker=speakers[0])[0]
        else:
            nbest_hyps_id = models[0].decode(
                batch['xs'], params,
                idx2token=dataloader.idx2token[0],
                exclude_eos=True,
                refs_id=batch['ys'],
                utt_ids=batch['utt_ids'],
                speakers=speakers,
                ensemble_models=models[1:] if len(models) > 1 else [],
                teacher_force=teacher_force)[0]

        for b in range(len(batch['xs'])):
            ref = batch['text'][b]
            nbest_hyps = [dataloader.idx2token[0](hyp_id)
                          for hyp_id in nbest_hyps_id[b]]

            # Write to trn
            speaker = str(batch['speakers'][b]).replace('-', '_')
            if streaming:
                utt_id = str(batch['utt_ids'][b]) + '_0000000_0000001'
            else:
                utt_id = str(batch['utt_ids'][b])
            if rank == 0:
                f_ref.write(ref + ' (' + speaker + '-' + utt_id + ')\n')
                f_hyp.write(nbest_hyps[0] + ' (' + speaker + '-' + utt_id + ')\n')
            logger.debug('utt-id (%d/%d): %s' % (n_utt + 1, len(dataloader), utt_id))
            logger.debug('Ref: %s' % ref)
            logger.debug('Hyp: %s' % nbest_hyps[0])
            logger.debug('-' * 150)

            if edit_distance and not streaming:
                # Compute PER
                err_b, sub_b, ins_b, del_b = compute_wer(ref=ref.split(' '),
                                                         hyp=nbest_hyps[0].split(' '))
                per += err_b
                n_sub += sub_b
                n_ins += ins_b
                n_del += del_b
                n_phone += len(ref.split(' '))

                # Compute oracle PER (best hypothesis in the n-best list)
                if oracle and len(nbest_hyps) > 1:
                    pers_b = [err_b] + [compute_wer(ref=ref.split(' '),
                                                    hyp=hyp_n.split(' '))[0]
                                        for hyp_n in nbest_hyps[1:]]
                    oracle_idx = np.argmin(np.array(pers_b))
                    if oracle_idx == 0:
                        n_oracle_hit += len(batch['utt_ids'])
                    per_oracle += pers_b[oracle_idx]

        n_utt += len(batch['utt_ids'])
        if progressbar:
            pbar.update(len(batch['utt_ids']))

    if rank == 0:
        f_hyp.close()
        f_ref.close()
    if progressbar:
        pbar.close()

    # Reset data counters
    dataloader.reset(is_new_epoch=True)

    if edit_distance and not streaming:
        # FIX: guard the normalizations — an empty evaluation set previously
        # raised ZeroDivisionError here.
        if n_phone > 0:
            per /= n_phone
            n_sub /= n_phone
            n_ins /= n_phone
            n_del /= n_phone

        # NOTE(review): PER is only logged under beam search (beam_width > 1);
        # confirm greedy decoding is intentionally silent here.
        if params.get('recog_beam_width') > 1:
            logger.info('PER (%s): %.2f %%' % (dataloader.set, per))
            logger.info('SUB: %.2f / INS: %.2f / DEL: %.2f' % (n_sub, n_ins, n_del))

        if oracle and n_phone > 0 and n_utt > 0:
            per_oracle /= n_phone
            oracle_hit_rate = n_oracle_hit * 100 / n_utt
            logger.info('Oracle PER (%s): %.2f %%' % (dataloader.set, per_oracle))
            logger.info('Oracle hit rate (%s): %.2f %%' % (dataloader.set, oracle_hit_rate))

        if fine_grained:
            for len_bin, pers in sorted(per_dist.items(), key=lambda x: x[0]):
                logger.info(' PER (%s): %.2f %% (%d)' %
                            (dataloader.set, sum(pers) / len(pers), len_bin))

    return per