def eval_word(models, dataset, beam_width, max_decode_len,
              beam_width_sub=1, max_decode_len_sub=300,
              eval_batch_size=None, length_penalty=0, progressbar=False,
              temperature=1, resolving_unk=False, a2c_oracle=False):
    """Evaluate the trained model by Word Error Rate (WER).
    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam in the main task
        max_decode_len (int): the maximum length of output sequences, at which
            to stop prediction. This is used for seq2seq models.
        beam_width_sub (int, optional): the size of beam in the sub task.
            This is used for the nested attention model.
        max_decode_len_sub (int, optional): the maximum length of output
            sequences in the sub task, at which to stop prediction.
            This is used for the nested attention model.
        eval_batch_size (int, optional): the batch size when evaluating the model
        length_penalty (float, optional): length penalty in beam search decoding
        progressbar (bool, optional): if True, visualize the progress bar
        temperature (int, optional):
        resolving_unk (bool, optional): if True, replace OOV words by referring
            to the sub-task (character-level) outputs
        a2c_oracle (bool, optional): if True, feed the ground-truth characters
            to the sub task (teacher forcing)
    Returns:
        wer (float): Word Error Rate
        df_wer (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    idx2word = Idx2word(dataset.vocab_file_path)
    if models[0].model_type == 'nested_attention':
        char2idx = Char2idx(dataset.vocab_file_path_sub)
    if models[0].model_type in ['ctc', 'attention'] and resolving_unk:
        idx2char = Idx2char(
            dataset.vocab_file_path_sub,
            capital_divide=dataset.label_type_sub == 'character_capital_divide')

    wer = 0
    sub, ins, dele = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this

    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)
        batch_size = len(batch['xs'])

        # Decode
        if len(models) > 1:
            # Ensemble decoding: average the posteriors over all models
            assert models[0].model_type in ['ctc']
            for i, model in enumerate(models):
                probs, x_lens, perm_idx = model.posteriors(
                    batch['xs'], batch['x_lens'])
                if i == 0:
                    probs_ensemble = probs
                else:
                    probs_ensemble += probs
            probs_ensemble /= len(models)

            best_hyps = models[0].decode_from_probs(
                probs_ensemble, x_lens, beam_width=1)
        else:
            model = models[0]
            # TODO: fix this

            if model.model_type == 'nested_attention':
                if a2c_oracle:
                    if dataset.is_test:
                        max_label_num = 0
                        for b in range(batch_size):
                            if max_label_num < len(list(batch['ys_sub'][b][0])):
                                max_label_num = len(list(batch['ys_sub'][b][0]))

                        ys_sub = np.zeros((batch_size, max_label_num),
                                          dtype=np.int32)
                        ys_sub -= 1  # pad with -1
                        y_lens_sub = np.zeros((batch_size,), dtype=np.int32)
                        for b in range(batch_size):
                            indices = char2idx(batch['ys_sub'][b][0])
                            ys_sub[b, :len(indices)] = indices
                            y_lens_sub[b] = len(indices)
                        # NOTE: transcript is separated by space ('_')
                else:
                    ys_sub = batch['ys_sub']
                    y_lens_sub = batch['y_lens_sub']

                best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    beam_width_sub=beam_width_sub,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_label_num if a2c_oracle else max_decode_len_sub,
                    length_penalty=length_penalty,
                    teacher_forcing=a2c_oracle,
                    ys_sub=ys_sub,
                    y_lens_sub=y_lens_sub)
            else:
                best_hyps, aw, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    length_penalty=length_penalty)
                if resolving_unk:
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'], batch['x_lens'],
                        beam_width=beam_width,
                        max_decode_len=max_decode_len_sub,
                        length_penalty=length_penalty,
                        task_index=1)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(batch_size):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space ('_')
            else:
                # Convert from list of indices to string
                str_ref = idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2word(best_hyps[b])
            if dataset.label_type == 'word':
                str_hyp = re.sub(r'(.*)_>(.*)', r'\1', str_hyp)
            else:
                str_hyp = re.sub(r'(.*)>(.*)', r'\1', str_hyp)
            # NOTE: truncate at the first <EOS>

            ##############################
            # Resolving UNK
            ##############################
            if resolving_unk and 'OOV' in str_hyp:
                str_hyp = resolve_unk(
                    str_hyp, best_hyps_sub[b], aw[b], aw_sub[b], idx2char)
                str_hyp = str_hyp.replace('*', '')

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)
            # NOTE: @ means noise

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            # Compute WER
            try:
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_words += len(str_ref.split('_'))
            except Exception:
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    wer /= num_words
    sub /= num_words
    ins /= num_words
    dele /= num_words

    df_wer = pd.DataFrame(
        {'SUB': [sub * 100], 'INS': [ins * 100], 'DEL': [dele * 100]},
        columns=['SUB', 'INS', 'DEL'],
        index=['WER'])

    return wer, df_wer
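# A minimal, runnable sketch (not part of the original script) of the
# per-utterance hypothesis post-processing performed in `eval_word` above:
# truncate at the word-level <EOS> marker ('_>'), strip '@' (noise) and '>'
# labels, and collapse repeated '_' separators before splitting into words.
# The sample string and the helper name `_demo_postprocess` are made up for
# illustration.
def _demo_postprocess():
    import re
    str_hyp = 'i_am_@_fine_>_trailing_garbage'
    str_hyp = re.sub(r'(.*)_>(.*)', r'\1', str_hyp)  # truncate at <EOS>
    str_hyp = re.sub(r'[@>]+', '', str_hyp)          # remove garbage labels
    str_hyp = re.sub(r'[_]+', '_', str_hyp)          # collapse spaces
    assert str_hyp == 'i_am_fine'
    return str_hyp.split('_')  # the word list handed to compute_wer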
def main():

    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset
    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='eval1',
        # data_type='eval2',
        # data_type='eval3',
        data_size=params['data_size'],
        label_type=params['label_type'],
        label_type_sub=params['label_type_sub'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=False,
        reverse=False,
        tool=params['tool'])
    params['num_classes'] = dataset.num_classes
    params['num_classes_sub'] = dataset.num_classes_sub

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    ######################################################################

    word2char = Word2char(dataset.vocab_file_path,
                          dataset.vocab_file_path_sub)

    for batch, is_new_epoch in dataset:
        # Decode
        if model.model_type == 'nested_attention':
            best_hyps, aw, best_hyps_sub, aw_sub, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=args.beam_width,
                beam_width_sub=args.beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_WORD,
                max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
        else:
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
            best_hyps_sub, aw_sub, _ = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=args.beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_CHAR,
                min_decode_len=MIN_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                task_index=1)

        if model.model_type == 'hierarchical_attention' and args.joint_decoding is not None:
            best_hyps_joint, aw_joint, best_hyps_sub_joint, aw_sub_joint, _ = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                joint_decoding=args.joint_decoding,
                space_index=dataset.char2idx('_')[0],
                oov_index=dataset.word2idx('OOV')[0],
                word2char=word2char,
                idx2word=dataset.idx2word,
                idx2char=dataset.idx2char,
                score_sub_weight=args.score_sub_weight)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                str_ref_sub = ys_sub[b][0]
                # NOTE: transcript is separated by space ('_')
            else:
                # Convert from list of indices to string
                str_ref = dataset.idx2word(ys[b][:y_lens[b]])
                str_ref_sub = dataset.idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of indices to string
            str_hyp = dataset.idx2word(best_hyps[b])
            str_hyp_sub = dataset.idx2char(best_hyps_sub[b])
            if model.model_type == 'hierarchical_attention' and args.joint_decoding is not None:
                str_hyp_joint = dataset.idx2word(best_hyps_joint[b])
                str_hyp_sub_joint = dataset.idx2char(best_hyps_sub_joint[b])

            ##############################
            # Resolving UNK
            ##############################
            if 'OOV' in str_hyp and args.resolving_unk:
                str_hyp_no_unk = resolve_unk(
                    str_hyp, best_hyps_sub[b], aw[b], aw_sub[b],
                    dataset.idx2char)
            if model.model_type == 'hierarchical_attention' and args.joint_decoding is not None:
                if 'OOV' in str_hyp_joint and args.resolving_unk:
                    str_hyp_no_unk_joint = resolve_unk(
                        str_hyp_joint, best_hyps_sub_joint[b],
                        aw_joint[b], aw_sub_joint[b], dataset.idx2char)

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref         : %s' % str_ref.replace('_', ' '))
            print('Hyp (main)  : %s' % str_hyp.replace('_', ' '))
            print('Hyp (sub)   : %s' % str_hyp_sub.replace('_', ' '))
            if 'OOV' in str_hyp and args.resolving_unk:
                print('Hyp (no UNK): %s' % str_hyp_no_unk.replace('_', ' '))
            if model.model_type == 'hierarchical_attention' and args.joint_decoding is not None:
                print('===== joint decoding =====')
                print('Hyp (main)  : %s' % str_hyp_joint.replace('_', ' '))
                print('Hyp (sub)   : %s' % str_hyp_sub_joint.replace('_', ' '))
                if 'OOV' in str_hyp_joint and args.resolving_unk:
                    print('Hyp (no UNK): %s' % str_hyp_no_unk_joint.replace('_', ' '))

            try:
                wer, _, _, _ = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=re.sub(r'(.*)_>(.*)', r'\1', str_hyp).split('_'),
                    normalize=True)
                print('WER (main)  : %.3f %%' % (wer * 100))
                if dataset.label_type_sub == 'character_wb':
                    wer_sub, _, _, _ = compute_wer(
                        ref=str_ref_sub.split('_'),
                        hyp=re.sub(r'(.*)>(.*)', r'\1',
                                   str_hyp_sub).split('_'),
                        normalize=True)
                    print('WER (sub)   : %.3f %%' % (wer_sub * 100))
                else:
                    cer, _, _, _ = compute_wer(
                        ref=list(str_ref_sub.replace('_', '')),
                        hyp=list(re.sub(r'(.*)>(.*)', r'\1',
                                        str_hyp_sub).replace('_', '')),
                        normalize=True)
                    print('CER (sub)   : %.3f %%' % (cer * 100))
                if 'OOV' in str_hyp and args.resolving_unk:
                    wer_no_unk, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                   str_hyp_no_unk.replace('*', '')).split('_'),
                        normalize=True)
                    print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))
                if model.model_type == 'hierarchical_attention' and args.joint_decoding is not None:
                    print('===== joint decoding =====')
                    wer_joint, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                   str_hyp_joint).split('_'),
                        normalize=True)
                    print('WER (main)  : %.3f %%' % (wer_joint * 100))
                    if 'OOV' in str_hyp_joint and args.resolving_unk:
                        wer_no_unk_joint, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                       str_hyp_no_unk_joint.replace('*', '')).split('_'),
                            normalize=True)
                        print('WER (no UNK): %.3f %%' % (wer_no_unk_joint * 100))
            except Exception:
                print('--- skipped ---')
            print('\n')

        if is_new_epoch:
            break
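# A hypothetical command line for this evaluation script (the flag names
# mirror the `args.*` attributes referenced in `main`; the script name,
# paths, and values are placeholders, and the actual parser definition
# lives elsewhere in the repository):
#
#   python eval.py --data_save_path /path/to/features \
#       --model_path /path/to/saved_model --epoch 25 --eval_batch_size 1 \
#       --beam_width 4 --beam_width_sub 4 --length_penalty 0 \
#       --coverage_penalty 0 --joint_decoding onepass \
#       --score_sub_weight 0.3 --resolving_unk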
def decode(model, dataset, beam_width, beam_width_sub,
           eval_batch_size=None, save_path=None, resolving_unk=False):
    """Visualize label outputs.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam in the main task
        beam_width_sub (int): the size of beam in the sub task
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string, optional): path to save decoding results
        resolving_unk (bool, optional): if True, replace OOV words by referring
            to the sub-task (character-level) outputs
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)
    idx2char = Idx2char(
        vocab_file_path=dataset.vocab_file_path_sub,
        capital_divide=dataset.label_type_sub == 'character_capital_divide')

    # Read GLM file
    glm = GLM(
        glm_path='/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:
        # Decode
        if model.model_type == 'nested_attention':
            best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                beam_width_sub=beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_WORD,
                max_decode_len_sub=MAX_DECODE_LEN_CHAR)
        else:
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD)
            best_hyps_sub, aw_sub, _ = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_CHAR,
                task_index=1)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref_original = ys[b][0]
                str_ref_sub = ys_sub[b][0]
                # NOTE: transcript is separated by space ('_')
            else:
                # Convert from list of indices to string
                str_ref_original = idx2word(ys[b][:y_lens[b]])
                str_ref_sub = idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of indices to string
            str_hyp = idx2word(best_hyps[b])
            str_hyp_sub = idx2char(best_hyps_sub[b])

            ##############################
            # Resolving UNK
            ##############################
            if 'OOV' in str_hyp and resolving_unk:
                str_hyp_no_unk = resolve_unk(
                    str_hyp, best_hyps_sub[b], aw[b], aw_sub[b], idx2char)

            # if 'OOV' not in str_hyp:
            #     continue

            ##############################
            # Post-processing
            ##############################
            str_ref = fix_trans(str_ref_original, glm)
            str_ref_sub = fix_trans(str_ref_sub, glm)
            str_hyp = fix_trans(str_hyp, glm)
            str_hyp_sub = fix_trans(str_hyp_sub, glm)
            if 'OOV' in str_hyp and resolving_unk:
                str_hyp_no_unk = fix_trans(str_hyp_no_unk, glm)

            if len(str_ref) == 0:
                continue

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref         : %s' % str_ref.replace('_', ' '))
            print('Hyp (main)  : %s' % str_hyp.replace('_', ' '))
            print('Hyp (sub)   : %s' % str_hyp_sub.replace('_', ' '))
            if 'OOV' in str_hyp and resolving_unk:
                print('Hyp (no UNK): %s' % str_hyp_no_unk.replace('_', ' '))

            try:
                # Compute WER
                wer, _, _, _ = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=re.sub(r'_>.*', '', str_hyp).split('_'),
                    normalize=True)
                print('WER (main)  : %.3f %%' % (wer * 100))
                wer_sub, _, _, _ = compute_wer(
                    ref=str_ref_sub.split('_'),
                    hyp=re.sub(r'>.*', '', str_hyp_sub).split('_'),
                    normalize=True)
                print('WER (sub)   : %.3f %%' % (wer_sub * 100))
                if 'OOV' in str_hyp and resolving_unk:
                    wer_no_unk, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'_>.*', '',
                                   str_hyp_no_unk.replace('*', '')).split('_'),
                        normalize=True)
                    print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))
            except Exception:
                print('--- skipped ---')

        if is_new_epoch:
            break
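# Design note on the <EOS> stripping in `decode` above: `str.replace` treats
# its first argument as a literal substring, so passing a regex pattern such
# as r'_>.*' to it matches nothing; truncating everything after the <EOS>
# marker requires `re.sub`. A minimal demonstration with a made-up hypothesis
# string (the helper name `_demo_eos_strip` is for illustration only):
def _demo_eos_strip():
    import re
    hyp = 'yeah_right_>_noise'
    assert hyp.replace(r'_>.*', '') == hyp             # literal: no match
    assert re.sub(r'_>.*', '', hyp) == 'yeah_right'    # regex: truncates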
def eval_word(models, dataset, eval_batch_size, beam_width, max_decode_len,
              min_decode_len=0, beam_width_sub=1, max_decode_len_sub=200,
              min_decode_len_sub=0, length_penalty=0, coverage_penalty=0,
              progressbar=False, resolving_unk=False, a2c_oracle=False,
              joint_decoding=None, score_sub_weight=0):
    """Evaluate the trained model by Word Error Rate (WER).
    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam in the main task
        max_decode_len (int): the maximum sequence length of tokens in the main task
        min_decode_len (int): the minimum sequence length of tokens in the main task
        beam_width_sub (int): the size of beam in the sub task.
            This is used for the nested attention model.
        max_decode_len_sub (int): the maximum sequence length of tokens in the sub task
        min_decode_len_sub (int): the minimum sequence length of tokens in the sub task
        length_penalty (float): length penalty in beam search decoding
        coverage_penalty (float): coverage penalty in beam search decoding
        progressbar (bool): if True, visualize the progress bar
        resolving_unk (bool): if True, replace OOV words by referring to the
            sub-task (character-level) outputs
        a2c_oracle (bool): if True, feed the ground-truth characters to the
            sub task (teacher forcing)
        joint_decoding (string): 'onepass', 'rescoring', or None
        score_sub_weight (float): the weight of sub-task scores in joint decoding
    Returns:
        wer (float): Word Error Rate
        df_word (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO: fix this

    if model.model_type == 'hierarchical_attention' and joint_decoding is not None:
        word2char = Word2char(dataset.vocab_file_path,
                              dataset.vocab_file_path_sub)

    wer = 0
    sub, ins, dele = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this

    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)
        batch_size = len(batch['xs'])

        # Decode
        if model.model_type == 'nested_attention':
            if a2c_oracle:
                if dataset.is_test:
                    max_label_num = 0
                    for b in range(batch_size):
                        if max_label_num < len(list(batch['ys_sub'][b][0])):
                            max_label_num = len(list(batch['ys_sub'][b][0]))

                    ys_sub = np.zeros((batch_size, max_label_num),
                                      dtype=np.int32)
                    ys_sub -= 1  # pad with -1
                    y_lens_sub = np.zeros((batch_size,), dtype=np.int32)
                    for b in range(batch_size):
                        indices = dataset.char2idx(batch['ys_sub'][b][0])
                        ys_sub[b, :len(indices)] = indices
                        y_lens_sub[b] = len(indices)
                    # NOTE: transcript is separated by space ('_')
            else:
                ys_sub = batch['ys_sub']
                y_lens_sub = batch['y_lens_sub']

            best_hyps, aw, best_hyps_sub, aw_sub, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                beam_width_sub=beam_width_sub,
                max_decode_len_sub=max_label_num if a2c_oracle else max_decode_len_sub,
                min_decode_len_sub=min_decode_len_sub,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty,
                teacher_forcing=a2c_oracle,
                ys_sub=ys_sub,
                y_lens_sub=y_lens_sub)
        elif model.model_type == 'hierarchical_attention' and joint_decoding is not None:
            best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty,
                joint_decoding=joint_decoding,
                space_index=dataset.char2idx('_')[0],
                oov_index=dataset.word2idx('OOV')[0],
                word2char=word2char,
                idx2word=dataset.idx2word,
                idx2char=dataset.idx2char,
                score_sub_weight=score_sub_weight)
        else:
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty)
            if resolving_unk:
                best_hyps_sub, aw_sub, _ = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len_sub,
                    min_decode_len=min_decode_len_sub,
                    length_penalty=length_penalty,
                    coverage_penalty=coverage_penalty,
                    task_index=1)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(batch_size):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space ('_')
            else:
                # Convert from list of indices to string
                str_ref = dataset.idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = dataset.idx2word(best_hyps[b])
            if dataset.label_type == 'word':
                str_hyp = re.sub(r'(.*)_>(.*)', r'\1', str_hyp)
            else:
                str_hyp = re.sub(r'(.*)>(.*)', r'\1', str_hyp)
            # NOTE: truncate at the first <EOS>

            ##############################
            # Resolving UNK
            ##############################
            if resolving_unk and 'OOV' in str_hyp:
                str_hyp = resolve_unk(
                    str_hyp, best_hyps_sub[b], aw[b], aw_sub[b],
                    dataset.idx2char)
                str_hyp = str_hyp.replace('*', '')

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)
            # NOTE: @ means noise

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            # Compute WER
            try:
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_words += len(str_ref.split('_'))
            except Exception:
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    wer /= num_words
    sub /= num_words
    ins /= num_words
    dele /= num_words

    df_word = pd.DataFrame(
        {'SUB': [sub * 100], 'INS': [ins * 100], 'DEL': [dele * 100]},
        columns=['SUB', 'INS', 'DEL'],
        index=['WER'])

    return wer, df_word
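# An illustrative word-level edit-distance sketch (not the project's
# `compute_wer`, whose implementation lives elsewhere): standard dynamic
# programming over substitutions, insertions, and deletions. It returns the
# total number of word errors, i.e. the numerator of WER; `compute_wer`
# additionally reports the substitution/insertion/deletion breakdown that is
# accumulated into `df_word` above. The helper name `_demo_word_errors` and
# the example strings are made up for illustration.
def _demo_word_errors(ref, hyp):
    """Return the minimum number of word edits turning `hyp` into `ref`."""
    prev = list(range(len(hyp) + 1))   # empty ref vs hyp[:j]: all insertions
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)    # ref[:i] vs empty hyp: all deletions
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(prev[j - 1] + cost,  # match / substitution
                          prev[j] + 1,         # deletion (ref word missed)
                          curr[j - 1] + 1)     # insertion (extra hyp word)
        prev = curr
    return prev[len(hyp)]

# e.g. _demo_word_errors('how_are_you'.split('_'),
#                        'how_art_thou_sir'.split('_'))
# == 3 (two substitutions plus one insertion)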