def do_eval_wer(model, dataset, beam_width, max_decode_len,
                eval_batch_size=None, progressbar=False):
    """Evaluate trained model by Word Error Rate.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset' class
        beam_width (int): the size of beam
        max_decode_len (int): the length of output sequences to stop
            prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating
            the model
        progressbar (bool, optional): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        df_wer (pd.DataFrame): dataframe of substitution, insertion,
            and deletion
    """
    # Reset data counter
    dataset.reset()

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)

    wer = 0
    # FIX: removed stray trailing comma in the original target list
    # (`sub, ins, dele, = ...`), which was legal but confusing.
    sub, ins, dele = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # Decode
        best_hyps, perm_idx = model.decode(batch['xs'], batch['x_lens'],
                                           beam_width=beam_width,
                                           max_decode_len=max_decode_len)
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2word(best_hyps[b])
            if 'attention' in model.model_type:
                str_hyp = str_hyp.split('>')[0]
                # NOTE: Truncate by the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[\'>]+', '', str_ref)
            str_hyp = re.sub(r'[\'>]+', '', str_hyp)
            # TODO: is it OK to strip these before computing WER?

            # Compute WER
            wer_b, sub_b, ins_b, del_b = compute_wer(ref=str_ref.split('_'),
                                                     hyp=str_hyp.split('_'),
                                                     normalize=True)
            wer += wer_b
            sub += sub_b
            ins += ins_b
            dele += del_b
            num_words += len(str_ref.split('_'))

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters so the dataset can be iterated again
    # (consistent with eval_phone in this file)
    dataset.reset()

    # FIX: guard against an empty dataset, which previously raised
    # ZeroDivisionError; all metrics stay 0 in that case.
    if num_words > 0:
        wer /= num_words
        sub /= num_words
        ins /= num_words
        dele /= num_words

    df_wer = pd.DataFrame(
        {
            'SUB': [sub * 100],
            'INS': [ins * 100],
            'DEL': [dele * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER'])

    return wer, df_wer
def do_eval_cer(save_paths, dataset, data_type, label_type, num_classes,
                beam_width, temperature_infer, is_test=False,
                progressbar=False):
    """Evaluate trained model (posterior ensemble) by Character Error Rate.

    Args:
        save_paths (list): directories holding per-utterance saved posteriors
        dataset: An instance of a `Dataset` class
        data_type (string):
        label_type (string): character
        num_classes (int):
        beam_width (int): the size of beam
        temperature_infer (int): temperature in the inference stage
        is_test (bool, optional): set to True when evaluating by the test set
        progressbar (bool, optional): if True, visualize the progressbar
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    """
    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
        char2idx = Char2idx(
            map_file_path='../metrics/mapping_files/character.txt')
    else:
        raise TypeError

    # Define decoder (last class index is the CTC blank)
    decoder = BeamSearchDecoder(space_index=char2idx(str_char='_')[0],
                                blank_index=num_classes - 1)

    ##################################################
    # Compute mean probabilities
    ##################################################
    if progressbar:
        pbar = tqdm(total=len(dataset))
    cer_mean, wer_mean = 0, 0
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini batch
        inputs, labels_true, inputs_seq_len, input_names = data
        batch_size = inputs[0].shape[0]

        for i_batch in range(batch_size):
            # Average posteriors over all models in the ensemble
            probs_ensemble = None
            for i_model in range(len(save_paths)):
                # Load posteriors
                speaker = input_names[0][i_batch].split('-')[0]
                prob_save_path = join(save_paths[i_model],
                                      'temp' + str(temperature_infer),
                                      data_type, 'probs_utt', speaker,
                                      input_names[0][i_batch] + '.npy')
                probs_model_i = np.load(prob_save_path)
                # NOTE: probs_model_i: `[T, num_classes]`

                # Sum over probs
                if probs_ensemble is None:
                    probs_ensemble = probs_model_i
                else:
                    probs_ensemble += probs_model_i

            # Compute mean posteriors
            probs_ensemble /= len(save_paths)

            # Decode per utterance
            labels_pred, scores = decoder(
                probs=probs_ensemble[np.newaxis, :, :],
                seq_len=inputs_seq_len[0][i_batch:i_batch + 1],
                beam_width=beam_width)

            # Convert from list of index to string
            if is_test:
                str_true = labels_true[0][i_batch][0]
                # NOTE: transcript is separated by space('_')
            else:
                str_true = idx2char(labels_true[0][i_batch],
                                    padded_value=dataset.padded_value)
            str_pred = idx2char(labels_pred[0])

            # Remove consecutive spaces
            str_pred = re.sub(r'[_]+', '_', str_pred)

            # Remove garbage labels
            str_true = re.sub(r'[\']+', '', str_true)
            str_pred = re.sub(r'[\']+', '', str_pred)

            # Compute WER
            # BUG FIX: ref and hyp were swapped in the original call
            # (ref=str_pred, hyp=str_true), which mirrors insertion/deletion
            # counts and normalizes by the wrong length. The commented-out
            # wer_align call below shows the intended orientation.
            wer_mean += compute_wer(ref=str_true.split('_'),
                                    hyp=str_pred.split('_'),
                                    normalize=True)
            # substitute, insert, delete = wer_align(
            #     ref=str_true.split('_'),
            #     hyp=str_pred.split('_'))
            # print('SUB: %d' % substitute)
            # print('INS: %d' % insert)
            # print('DEL: %d' % delete)

            # Remove spaces
            str_true = re.sub(r'[_]+', '', str_true)
            str_pred = re.sub(r'[_]+', '', str_pred)

            # Compute CER
            cer_mean += compute_cer(str_pred=str_pred,
                                    str_true=str_true,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    # NOTE(review): divides by the number of dataset steps, not utterances —
    # TODO confirm len(dataset) is the intended denominator
    cer_mean /= (len(dataset))
    wer_mean /= (len(dataset))  # TODO: Fix this

    return cer_mean, wer_mean
def eval_phone(model, dataset, map_file_path, eval_batch_size, beam_width,
               max_decode_len, min_decode_len=0, length_penalty=0,
               coverage_penalty=0, progressbar=False):
    """Evaluate trained model by Phone Error Rate.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset' class
        map_file_path (string): path to phones.60-48-39.map
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam
        max_decode_len (int): the length of output sequences to stop
            prediction. This is used for seq2seq models.
        min_decode_len (int): the minimum sequence length to emit
        length_penalty (float): length penalty in beam search decoding
        coverage_penalty (float): coverage penalty in beam search decoding
        progressbar (bool): if True, visualize the progressbar
    Returns:
        per (float): Phone error rate
        df_phone (pd.DataFrame): dataframe of substitution, insertion,
            and deletion
    """
    # Reset data counter
    dataset.reset()

    map2phone39 = Map2phone39(label_type=dataset.label_type,
                              map_file_path=map_file_path)

    per = 0
    sub, ins, dele = 0, 0, 0
    num_phones = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # Decode
        best_hyps, _, perm_idx = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=max_decode_len,
            min_decode_len=min_decode_len,
            length_penalty=length_penalty,
            coverage_penalty=coverage_penalty)
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                phone_ref_list = ys[b][0].split(' ')
                # NOTE: transcript is separated by space(' ')
            else:
                # Convert from index to phone (-> list of phone strings)
                phone_ref_list = dataset.idx2phone(
                    ys[b][:y_lens[b]]).split(' ')

            ##############################
            # Hypothesis
            ##############################
            # Convert from index to phone (-> list of phone strings)
            str_hyp = dataset.idx2phone(best_hyps[b])
            str_hyp = re.sub(r'(.*) >(.*)', r'\1', str_hyp)
            # NOTE: Truncate by the first <EOS>
            phone_hyp_list = str_hyp.split(' ')

            # Mapping to 39 phones (-> list of phone strings)
            if dataset.label_type != 'phone39':
                phone_ref_list = map2phone39(phone_ref_list)
                phone_hyp_list = map2phone39(phone_hyp_list)

            # Compute PER
            try:
                per_b, sub_b, ins_b, del_b = compute_wer(
                    ref=phone_ref_list,
                    hyp=phone_hyp_list,
                    normalize=False)
                per += per_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_phones += len(phone_ref_list)
            except Exception:
                # FIX: the original bare `except:` also swallowed
                # KeyboardInterrupt/SystemExit. The best-effort skip of a
                # failing utterance is preserved for ordinary errors only.
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # FIX: guard against an empty dataset or all-skipped utterances,
    # which previously raised ZeroDivisionError.
    if num_phones > 0:
        per /= num_phones
        sub /= num_phones
        ins /= num_phones
        dele /= num_phones

    df_phone = pd.DataFrame(
        {
            'SUB': [sub * 100],
            'INS': [ins * 100],
            'DEL': [dele * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['PER'])

    return per, df_phone
def decode(model, dataset, beam_width, eval_batch_size=None, save_path=None):
    """Visualize label outputs.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam
        eval_batch_size (int, optional): the batch size when evaluating
            the model
        save_path (string): path to save decoding results
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    if 'char' in dataset.label_type:
        map_fn = Idx2char(
            dataset.vocab_file_path,
            capital_divide=dataset.label_type == 'character_capital_divide')
        max_decode_len = MAX_DECODE_LEN_CHAR
    else:
        map_fn = Idx2word(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_WORD

    # Read GLM file
    glm = GLM(
        glm_path=
        '/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm'
    )

    if save_path is not None:
        # NOTE(review): stdout is redirected and the file handle is never
        # closed or restored — intentional for a one-shot script, but leaks.
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:
        # Decode
        if model.model_type == 'nested_attention':
            best_hyps, _, best_hyps_sub, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                max_decode_len_sub=max_decode_len)
        else:
            best_hyps, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len)

        if model.model_type == 'attention' and model.ctc_loss_weight > 0:
            best_hyps_ctc, perm_idx = model.decode_ctc(batch['xs'],
                                                       batch['x_lens'],
                                                       beam_width=beam_width)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref_original = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref_original = map_fn(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = map_fn(best_hyps[b])

            ##############################
            # Post-processing
            ##############################
            str_ref = fix_trans(str_ref_original, glm)
            str_hyp = fix_trans(str_hyp, glm)

            if len(str_ref) == 0:
                continue

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref: %s' % str_ref.replace('_', ' '))
            print('Hyp: %s' % str_hyp.replace('_', ' '))
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                str_hyp_ctc = map_fn(best_hyps_ctc[b])
                print('Hyp (CTC): %s' % str_hyp_ctc)

            try:
                # Compute WER
                # BUG FIX: the original used
                # str_hyp.replace(r'_>.*', '').replace(r'>.*', ''),
                # but str.replace matches a LITERAL substring, not a regex,
                # so the intended truncation at the first <EOS> marker never
                # happened. re.sub performs the intended truncation.
                wer, _, _, _ = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=re.sub(r'_?>.*', '', str_hyp).split('_'),
                    normalize=True)
                print('WER: %.3f %%' % (wer * 100))
                if model.ctc_loss_weight > 0:
                    wer_ctc, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=str_hyp_ctc.split('_'),
                        normalize=True)
                    print('WER (CTC): %.3f %%' % (wer_ctc * 100))
            except Exception:
                # FIX: narrowed from a bare `except:` so Ctrl-C still works.
                print('--- skipped ---')

        if is_new_epoch:
            break
def do_eval_cer(save_paths, dataset, data_type, label_type, num_classes,
                beam_width, temperature_infer, is_test=False,
                progressbar=False):
    """Evaluate trained model (posterior ensemble) by Character Error Rate.

    Args:
        save_paths (list): directories holding per-utterance saved posteriors
        dataset: An instance of a `Dataset` class
        data_type (string):
        label_type (string): character
        num_classes (int):
        beam_width (int): the size of beam
        temperature_infer (int): temperature in the inference stage
        is_test (bool, optional): set to True when evaluating by the test set
        progressbar (bool, optional): if True, visualize the progressbar
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    """
    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
        char2idx = Char2idx(
            map_file_path='../metrics/mapping_files/character.txt')
    else:
        raise TypeError

    # Define decoder (last class index is the CTC blank)
    decoder = BeamSearchDecoder(space_index=char2idx(str_char='_')[0],
                                blank_index=num_classes - 1)

    ##################################################
    # Compute mean probabilities
    ##################################################
    if progressbar:
        pbar = tqdm(total=len(dataset))
    cer_mean, wer_mean = 0, 0
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini batch
        inputs, labels_true, inputs_seq_len, input_names = data
        batch_size = inputs[0].shape[0]

        for i_batch in range(batch_size):
            # Average posteriors over all models in the ensemble
            probs_ensemble = None
            for i_model in range(len(save_paths)):
                # Load posteriors
                speaker = input_names[0][i_batch].split('-')[0]
                prob_save_path = join(save_paths[i_model],
                                      'temp' + str(temperature_infer),
                                      data_type, 'probs_utt', speaker,
                                      input_names[0][i_batch] + '.npy')
                probs_model_i = np.load(prob_save_path)
                # NOTE: probs_model_i: `[T, num_classes]`

                # Sum over probs
                if probs_ensemble is None:
                    probs_ensemble = probs_model_i
                else:
                    probs_ensemble += probs_model_i

            # Compute mean posteriors
            probs_ensemble /= len(save_paths)

            # Decode per utterance
            labels_pred, scores = decoder(
                probs=probs_ensemble[np.newaxis, :, :],
                seq_len=inputs_seq_len[0][i_batch:i_batch + 1],
                beam_width=beam_width)

            # Convert from list of index to string
            if is_test:
                str_true = labels_true[0][i_batch][0]
                # NOTE: transcript is separated by space('_')
            else:
                str_true = idx2char(labels_true[0][i_batch],
                                    padded_value=dataset.padded_value)
            str_pred = idx2char(labels_pred[0])

            # Remove consecutive spaces
            str_pred = re.sub(r'[_]+', '_', str_pred)

            # Remove garbage labels
            str_true = re.sub(r'[\']+', '', str_true)
            str_pred = re.sub(r'[\']+', '', str_pred)

            # Compute WER
            # BUG FIX: ref and hyp were swapped in the original call
            # (ref=str_pred, hyp=str_true), which mirrors insertion/deletion
            # counts and normalizes by the wrong length. The commented-out
            # wer_align call below shows the intended orientation.
            wer_mean += compute_wer(ref=str_true.split('_'),
                                    hyp=str_pred.split('_'),
                                    normalize=True)
            # substitute, insert, delete = wer_align(
            #     ref=str_true.split('_'),
            #     hyp=str_pred.split('_'))
            # print('SUB: %d' % substitute)
            # print('INS: %d' % insert)
            # print('DEL: %d' % delete)

            # Remove spaces
            str_true = re.sub(r'[_]+', '', str_true)
            str_pred = re.sub(r'[_]+', '', str_pred)

            # Compute CER
            cer_mean += compute_cer(str_pred=str_pred,
                                    str_true=str_true,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    # NOTE(review): divides by the number of dataset steps, not utterances —
    # TODO confirm len(dataset) is the intended denominator
    cer_mean /= (len(dataset))
    wer_mean /= (len(dataset))  # TODO: Fix this

    return cer_mean, wer_mean
def eval_char(models, eval_batch_size, dataset, beam_width, max_decode_len,
              min_decode_len=0, length_penalty=0, coverage_penalty=0,
              progressbar=False):
    """Evaluate trained model by Character Error Rate.

    Args:
        models (list): the models to evaluate
        eval_batch_size (int): the batch size when evaluating the model
        dataset: An instance of a `Dataset' class
        beam_width (int): the size of beam
        max_decode_len (int): the maximum sequence length to emit
        min_decode_len (int): the minimum sequence length to emit
        length_penalty (float): length penalty in beam search decoding
        coverage_penalty (float): coverage penalty in beam search decoding
        progressbar (bool): if True, visualize the progressbar
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate
        df_word (pd.DataFrame): dataframe of substitution, insertion,
            and deletion
    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO: fix this (only the first model is used; no real ensembling)

    wer, cer = 0, 0
    sub_word, ins_word, del_word = 0, 0, 0
    sub_char, ins_char, del_char = 0, 0, 0
    num_words, num_chars = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # Decode: single-task models emit characters on the main task,
        # hierarchical models on the sub task (task_index=1).
        if model.model_type in ['ctc', 'attention']:
            best_hyps, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty,
                task_index=0)
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]
        else:
            best_hyps, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty,
                task_index=1)
            ys = batch['ys_sub'][perm_idx]
            y_lens = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = dataset.idx2char(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = dataset.idx2char(best_hyps[b])
            str_hyp = re.sub(r'(.*)>(.*)', r'\1', str_hyp)
            # NOTE: Truncate by the first <EOS>

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)
            # NOTE: @ means noise

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            try:
                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub_word += sub_b
                ins_word += ins_b
                del_word += del_b
                num_words += len(str_ref.split('_'))

                # Compute CER
                cer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=list(str_ref.replace('_', '')),
                    hyp=list(str_hyp.replace('_', '')),
                    normalize=False)
                cer += cer_b
                sub_char += sub_b
                ins_char += ins_b
                del_char += del_b
                num_chars += len(str_ref.replace('_', ''))
            except Exception:
                # FIX: the original bare `except:` also swallowed
                # KeyboardInterrupt/SystemExit; keep the best-effort skip
                # for ordinary errors only.
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # FIX: guard against an empty dataset / all-skipped utterances,
    # which previously raised ZeroDivisionError.
    if num_words > 0:
        wer /= num_words
        sub_word /= num_words
        ins_word /= num_words
        del_word /= num_words
    if num_chars > 0:
        cer /= num_chars
        sub_char /= num_chars
        ins_char /= num_chars
        del_char /= num_chars

    df_word = pd.DataFrame(
        {
            'SUB': [sub_word * 100, sub_char * 100],
            'INS': [ins_word * 100, ins_char * 100],
            'DEL': [del_word * 100, del_char * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER', 'CER'])

    return wer, cer, df_word
def main():
    """Decode the evaluation set with a hierarchical word/char model and
    print per-utterance references, hypotheses, and WER/CER to stdout.

    Reads the model path, decoding hyper-parameters, and options such as
    joint decoding and UNK resolution from the module-level argparse
    `parser`. Restores the checkpoint, iterates one epoch of the dataset,
    and prints results; nothing is returned.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset (eval1 split; switch the commented data_type lines to
    # evaluate eval2/eval3 instead)
    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='eval1',
        # data_type='eval2',
        # data_type='eval3',
        data_size=params['data_size'],
        label_type=params['label_type'],
        label_type_sub=params['label_type_sub'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=False,
        reverse=False,
        tool=params['tool'])
    params['num_classes'] = dataset.num_classes
    params['num_classes_sub'] = dataset.num_classes_sub

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    ######################################################################

    # Maps word-level indices to character-level indices (used by joint
    # decoding below)
    word2char = Word2char(dataset.vocab_file_path,
                          dataset.vocab_file_path_sub)

    for batch, is_new_epoch in dataset:
        # Decode
        if model.model_type == 'nested_attention':
            best_hyps, aw, best_hyps_sub, aw_sub, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=args.beam_width,
                beam_width_sub=args.beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_WORD,
                max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
        else:
            # Main (word-level) task
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
            # Sub (character-level) task
            best_hyps_sub, aw_sub, _ = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=args.beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_CHAR,
                min_decode_len=MIN_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                task_index=1)

        # Optional joint word/char decoding pass
        if model.model_type == 'hierarchical_attention' and args.joint_decoding is not None:
            best_hyps_joint, aw_joint, best_hyps_sub_joint, aw_sub_joint, _ = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                joint_decoding=args.joint_decoding,
                space_index=dataset.char2idx('_')[0],
                oov_index=dataset.word2idx('OOV')[0],
                word2char=word2char,
                idx2word=dataset.idx2word,
                idx2char=dataset.idx2char,
                score_sub_weight=args.score_sub_weight)

        # Re-order targets to match the permutation returned by decode
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                str_ref_sub = ys_sub[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = dataset.idx2word(ys[b][:y_lens[b]])
                str_ref_sub = dataset.idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = dataset.idx2word(best_hyps[b])
            str_hyp_sub = dataset.idx2char(best_hyps_sub[b])
            if model.model_type == 'hierarchical_attention' and args.joint_decoding is not None:
                str_hyp_joint = dataset.idx2word(best_hyps_joint[b])
                str_hyp_sub_joint = dataset.idx2char(best_hyps_sub_joint[b])

            ##############################
            # Resolving UNK
            ##############################
            # Replace OOV tokens in the word hypothesis using character-level
            # outputs and the two attention weight sequences
            if 'OOV' in str_hyp and args.resolving_unk:
                str_hyp_no_unk = resolve_unk(str_hyp, best_hyps_sub[b],
                                             aw[b], aw_sub[b],
                                             dataset.idx2char)
            if model.model_type == 'hierarchical_attention' and args.joint_decoding is not None:
                if 'OOV' in str_hyp_joint and args.resolving_unk:
                    str_hyp_no_unk_joint = resolve_unk(str_hyp_joint,
                                                       best_hyps_sub_joint[b],
                                                       aw_joint[b],
                                                       aw_sub_joint[b],
                                                       dataset.idx2char)

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref : %s' % str_ref.replace('_', ' '))
            print('Hyp (main) : %s' % str_hyp.replace('_', ' '))
            print('Hyp (sub) : %s' % str_hyp_sub.replace('_', ' '))
            if 'OOV' in str_hyp and args.resolving_unk:
                print('Hyp (no UNK): %s' % str_hyp_no_unk.replace('_', ' '))
            if model.model_type == 'hierarchical_attention' and args.joint_decoding is not None:
                print('===== joint decoding =====')
                print('Hyp (main) : %s' % str_hyp_joint.replace('_', ' '))
                print('Hyp (sub) : %s' % str_hyp_sub_joint.replace('_', ' '))
                if 'OOV' in str_hyp_joint and args.resolving_unk:
                    print('Hyp (no UNK): %s' %
                          str_hyp_no_unk_joint.replace('_', ' '))

            try:
                # Word-level WER of the main task; the re.sub truncates the
                # hypothesis at the first '_>' (EOS) marker
                wer, _, _, _ = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=re.sub(r'(.*)_>(.*)', r'\1', str_hyp).split('_'),
                    normalize=True)
                print('WER (main) : %.3f %%' % (wer * 100))
                if dataset.label_type_sub == 'character_wb':
                    # Sub task has word boundaries -> report WER
                    wer_sub, _, _, _ = compute_wer(
                        ref=str_ref_sub.split('_'),
                        hyp=re.sub(r'(.*)>(.*)', r'\1',
                                   str_hyp_sub).split('_'),
                        normalize=True)
                    print('WER (sub) : %.3f %%' % (wer_sub * 100))
                else:
                    # No word boundaries -> report CER over character lists
                    cer, _, _, _ = compute_wer(
                        ref=list(str_ref_sub.replace('_', '')),
                        hyp=list(
                            re.sub(r'(.*)>(.*)', r'\1',
                                   str_hyp_sub).replace('_', '')),
                        normalize=True)
                    print('CER (sub) : %.3f %%' % (cer * 100))
                if 'OOV' in str_hyp and args.resolving_unk:
                    # '*' marks resolved-UNK characters and is stripped
                    # before scoring
                    wer_no_unk, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                   str_hyp_no_unk.replace('*',
                                                          '')).split('_'),
                        normalize=True)
                    print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))
                if model.model_type == 'hierarchical_attention' and args.joint_decoding is not None:
                    print('===== joint decoding =====')
                    wer_joint, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                   str_hyp_joint).split('_'),
                        normalize=True)
                    print('WER (main) : %.3f %%' % (wer_joint * 100))
                    if 'OOV' in str_hyp_joint and args.resolving_unk:
                        wer_no_unk_joint, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                       str_hyp_no_unk_joint.replace(
                                           '*', '')).split('_'),
                            normalize=True)
                        print('WER (no UNK): %.3f %%' %
                              (wer_no_unk_joint * 100))
            # NOTE(review): bare except silently skips any scoring failure,
            # including KeyboardInterrupt — consider `except Exception:`
            except:
                print('--- skipped ---')
            print('\n')

        if is_new_epoch:
            break
def do_eval_per(session, decode_op, per_op, model, dataset, label_type,
                eval_batch_size=None, progressbar=False,
                is_multitask=False, is_jointctcatt=False):
    """Evaluate trained model by Phone Error Rate.

    Args:
        session: session of training model
        decode_op: operation for decoding
        per_op: operation for computing phone error rate
        model: the model to evaluate
        dataset: An instance of a `Dataset' class
        label_type (string): phone39 or phone48 or phone61
        eval_batch_size (int, optional): the batch size when evaluating
            the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
        is_jointctcatt (bool, optional): if True, evaluate the joint
            CTC-Attention model
    Returns:
        per_mean (float): An average of PER
    """
    # Reset data counter
    dataset.reset()

    train_label_type = label_type
    eval_label_type = dataset.label_type_sub if is_multitask else dataset.label_type

    # phone2idx_39_map_file_path = '../metrics/mapping_files/attention/phone39.txt'
    # Index->phone converters for the training and evaluation label sets
    idx2phone_train = Idx2phone(
        map_file_path='../metrics/mapping_files/attention/' +
        train_label_type + '.txt')
    idx2phone_eval = Idx2phone(
        map_file_path='../metrics/mapping_files/attention/' +
        eval_label_type + '.txt')
    # Collapse both label sets to the common 39-phone set before scoring
    map2phone39_train = Map2phone39(
        label_type=train_label_type,
        map_file_path='../metrics/mapping_files/phone2phone.txt')
    map2phone39_eval = Map2phone39(
        label_type=eval_label_type,
        map_file_path='../metrics/mapping_files/phone2phone.txt')

    per_mean = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini-batch; the tuple layout
        # depends on the model variant being evaluated
        if is_multitask:
            inputs, _, labels_true, inputs_seq_len, _, _ = data
        elif is_jointctcatt:
            inputs, labels_true, _, inputs_seq_len, _, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _, _ = data

        # Dropout keep-probabilities are set to 1.0 for evaluation
        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }
        batch_size_each = len(inputs_seq_len)

        # Evaluate by 39 phones
        labels_pred = session.run(decode_op, feed_dict=feed_dict)
        for i_batch in range(batch_size_each):
            ###############
            # Hypothesis
            ###############
            # Convert from num to phone (-> list of phone strings)
            phone_pred_list = idx2phone_train(
                labels_pred[i_batch]).split(' ')

            # Mapping to 39 phones (-> list of phone strings)
            phone_pred_list = map2phone39_train(phone_pred_list)

            ###############
            # Reference
            ###############
            # Convert from num to phone (-> list of phone strings)
            phone_true_list = idx2phone_eval(
                labels_true[i_batch]).split(' ')

            # Mapping to 39 phones (-> list of phone strings)
            phone_true_list = map2phone39_eval(phone_true_list)

            # Compute PER
            # NOTE(review): this keyword signature (str_pred/str_true/
            # space_mark) differs from the ref/hyp signature used by the
            # other compute_wer call sites in this file — presumably a
            # different compute_wer implementation; verify the import.
            per_mean += compute_wer(str_pred=' '.join(phone_pred_list),
                                    str_true=' '.join(phone_true_list),
                                    normalize=True,
                                    space_mark=' ')
            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    # NOTE(review): averaged over dataset steps, not utterances —
    # TODO confirm len(dataset) is the intended denominator
    per_mean /= len(dataset)

    return per_mean
def check(self, encoder_type, decoder_type, bidirectional=False,
          attention_type='location', subsample=False, projection=False,
          ctc_loss_weight_sub=0, conv=False, batch_norm=False,
          residual=False, dense_residual=False, num_heads=1,
          backward_sub=False):
    """Smoke-test one HierarchicalAttentionSeq2seq configuration.

    Builds the model (chainer backend) with the given architecture flags,
    trains it on a tiny generated word/char batch for up to 300 steps, and
    every 10 steps decodes both tasks and prints loss/WER/CER; stops early
    once CER drops below 0.1.

    Args:
        encoder_type (str): encoder architecture name (e.g. 'cnn')
        decoder_type (str): decoder architecture name
        bidirectional (bool): bidirectional encoder
        attention_type (str): attention mechanism name
        subsample (bool): enable encoder subsampling
        projection (bool): add encoder projection layers
        ctc_loss_weight_sub (float): CTC loss weight for the sub task
        conv (bool): prepend convolutional layers
        batch_norm (bool): enable batch normalization
        residual (bool): residual connections in encoder/decoder
        dense_residual (bool): dense residual connections
        num_heads (int): number of attention heads
        backward_sub (bool): decode the sub task right-to-left
    """
    print('==================================================')
    print(' encoder_type: %s' % encoder_type)
    print(' bidirectional: %s' % str(bidirectional))
    print(' projection: %s' % str(projection))
    print(' decoder_type: %s' % decoder_type)
    print(' attention_type: %s' % attention_type)
    print(' subsample: %s' % str(subsample))
    print(' ctc_loss_weight_sub: %s' % str(ctc_loss_weight_sub))
    print(' conv: %s' % str(conv))
    print(' batch_norm: %s' % str(batch_norm))
    print(' residual: %s' % str(residual))
    print(' dense_residual: %s' % str(dense_residual))
    print(' backward_sub: %s' % str(backward_sub))
    print(' num_heads: %s' % str(num_heads))
    print('==================================================')

    if conv or encoder_type == 'cnn':
        # pattern 1
        # conv_channels = [32, 32]
        # conv_kernel_sizes = [[41, 11], [21, 11]]
        # conv_strides = [[2, 2], [2, 1]]
        # poolings = [[], []]
        # pattern 2 (VGG like)
        conv_channels = [64, 64]
        conv_kernel_sizes = [[3, 3], [3, 3]]
        conv_strides = [[1, 1], [1, 1]]
        poolings = [[2, 2], [2, 2]]
    else:
        conv_channels = []
        conv_kernel_sizes = []
        conv_strides = []
        poolings = []

    # Load batch data
    splice = 1
    # Frame stacking is skipped when the encoder already reduces the
    # time resolution (subsampling or convolution)
    num_stack = 1 if subsample or conv or encoder_type == 'cnn' else 2
    xs, ys, ys_sub, x_lens, y_lens, y_lens_sub = generate_data(
        label_type='word_char',
        batch_size=2,
        num_stack=num_stack,
        splice=splice,
        backend='chainer')

    num_classes = 11
    num_classes_sub = 27

    # Load model
    model = HierarchicalAttentionSeq2seq(
        input_size=xs[0].shape[-1] // splice // num_stack,  # 120
        encoder_type=encoder_type,
        encoder_bidirectional=bidirectional,
        encoder_num_units=320,
        encoder_num_proj=320 if projection else 0,
        encoder_num_layers=2,
        encoder_num_layers_sub=1,
        attention_type=attention_type,
        attention_dim=128,
        decoder_type=decoder_type,
        decoder_num_units=320,
        decoder_num_layers=1,
        decoder_num_units_sub=320,
        decoder_num_layers_sub=1,
        embedding_dim=64,
        embedding_dim_sub=32,
        dropout_input=0.1,
        dropout_encoder=0.1,
        dropout_decoder=0.1,
        dropout_embedding=0.1,
        # When the sub task is trained with CTC, the attention sub loss
        # is disabled
        main_loss_weight=0.8,
        sub_loss_weight=0.2 if ctc_loss_weight_sub == 0 else 0,
        num_classes=num_classes,
        num_classes_sub=num_classes_sub,
        parameter_init_distribution='uniform',
        parameter_init=0.1,
        recurrent_weight_orthogonal=False,
        init_forget_gate_bias_with_one=True,
        subsample_list=[] if not subsample else [True, False],
        # NOTE(review): passes the boolean `subsample` itself as the
        # subsample_type string when subsampling is on — TODO confirm
        # this is intended (a type name like 'drop' would be expected)
        subsample_type='drop' if not subsample else subsample,
        bridge_layer=True,
        init_dec_state='first',
        sharpening_factor=1,
        logits_temperature=1,
        sigmoid_smoothing=False,
        ctc_loss_weight_sub=ctc_loss_weight_sub,
        attention_conv_num_channels=10,
        attention_conv_width=201,
        input_channel=3,
        num_stack=num_stack,
        splice=splice,
        conv_channels=conv_channels,
        conv_kernel_sizes=conv_kernel_sizes,
        conv_strides=conv_strides,
        poolings=poolings,
        activation='relu',
        batch_norm=batch_norm,
        scheduled_sampling_prob=0.1,
        scheduled_sampling_max_step=200,
        label_smoothing_prob=0.1,
        weight_noise_std=0,
        encoder_residual=residual,
        encoder_dense_residual=dense_residual,
        decoder_residual=residual,
        decoder_dense_residual=dense_residual,
        decoding_order='attend_generate_update',
        bottleneck_dim=256,
        bottleneck_dim_sub=256,
        backward_sub=backward_sub,
        num_heads=num_heads,
        num_heads_sub=num_heads)

    # Count total parameters
    for name in sorted(list(model.num_params_dict.keys())):
        num_params = model.num_params_dict[name]
        print("%s %d" % (name, num_params))
    print("Total %.3f M parameters" % (model.total_parameters / 1000000))

    # Define optimizer
    learning_rate = 1e-3
    model.set_optimizer('adam',
                        learning_rate_init=learning_rate,
                        weight_decay=1e-6,
                        lr_schedule=False,
                        factor=0.1,
                        patience_epoch=5)

    # Define learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               backend='chainer',
                               decay_start_epoch=20,
                               decay_rate=0.9,
                               decay_patient_epoch=10,
                               lower_better=True)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # Train model
    max_step = 300
    start_time_step = time.time()
    for step in range(max_step):
        # Step for parameter update
        loss, loss_main, loss_sub = model(xs, ys, x_lens, y_lens,
                                          ys_sub, y_lens_sub)
        model.optimizer.target.cleargrads()
        model.cleargrads()
        loss.backward()
        # Cut the graph to free memory after the backward pass
        loss.unchain_backward()
        model.optimizer.update()

        if (step + 1) % 10 == 0:
            # Compute loss (evaluation mode: no dropout/noise)
            loss, loss_main, loss_sub = model(xs, ys, x_lens, y_lens,
                                              ys_sub, y_lens_sub,
                                              is_eval=True)

            # Decode both tasks with greedy search
            best_hyps, _, _ = model.decode(
                xs, x_lens,
                beam_width=1,
                # beam_width=2,
                max_decode_len=30)
            best_hyps_sub, _, _ = model.decode(
                xs, x_lens,
                beam_width=1,
                # beam_width=2,
                max_decode_len=60,
                task_index=1)

            # NOTE(review): idx2word/idx2char are resolved from the
            # enclosing module/test fixtures, not defined in this method
            str_hyp = idx2word(best_hyps[0][:-1]).split('>')[0]
            str_ref = idx2word(ys[0])
            str_hyp_sub = idx2char(best_hyps_sub[0][:-1]).split('>')[0]
            str_ref_sub = idx2char(ys_sub[0])

            # Compute accuracy
            # NOTE(review): bare except also catches KeyboardInterrupt —
            # consider `except Exception:`
            try:
                wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                           hyp=str_hyp.split('_'),
                                           normalize=True)
                cer, _, _, _ = compute_wer(
                    ref=list(str_ref_sub.replace('_', '')),
                    hyp=list(str_hyp_sub.replace('_', '')),
                    normalize=True)
            except:
                wer = 1
                cer = 1

            duration_step = time.time() - start_time_step
            print(
                'Step %d: loss=%.3f(%.3f/%.3f) / wer=%.3f / cer=%.3f / lr=%.5f (%.3f sec)'
                % (step + 1, loss, loss_main, loss_sub, wer, cer,
                   learning_rate, duration_step))
            start_time_step = time.time()

            # Visualize
            print('Ref: %s' % str_ref)
            print('Hyp (word): %s' % str_hyp)
            print('Hyp (char): %s' % str_hyp_sub)

            if cer < 0.1:
                # NOTE(review): 'Modle' typo is in the original output string
                print('Modle is Converged.')
                break

            # Update learning rate
            model.optimizer, learning_rate = lr_controller.decay_lr(
                optimizer=model.optimizer,
                learning_rate=learning_rate,
                epoch=step,
                value=wer)
def eval_char(models, dataset, beam_width, max_decode_len,
              eval_batch_size=None, length_penalty=0, progressbar=False,
              temperature=1):
    """Evaluate trained model(s) by Character Error Rate.

    Args:
        models (list): the models to evaluate (a CTC ensemble when len > 1)
        dataset: An instance of a `Dataset' class
        beam_width (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the model
        length_penalty (float, optional): penalty applied per output token
        progressbar (bool, optional): if True, visualize the progressbar
        temperature (int, optional): softmax temperature used when computing
            posteriors for ensemble decoding
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate
        df_wer_cer (pd.DataFrame): dataframe of substitution, insertion,
            and deletion (rows: WER, CER)
    """
    # Reset data counter
    dataset.reset()

    # Single-task models read characters from the main vocab; hierarchical
    # models keep them in the sub-task vocab.
    if models[0].model_type in ['ctc', 'attention']:
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path)
    else:
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub)

    # Running totals; normalized by the reference token counts at the end
    wer, cer = 0, 0
    sub_word, ins_word, del_word = 0, 0, 0
    sub_char, ins_char, del_char = 0, 0, 0
    num_words, num_chars = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # Decode
        if len(models) > 1:
            # Ensemble: average frame-level posteriors over the models,
            # then decode the averaged distribution once.
            assert models[0].model_type in ['ctc']
            for i, model in enumerate(models):
                probs, x_lens, perm_idx = model.posteriors(
                    batch['xs'], batch['x_lens'], temperature=temperature)
                if i == 0:
                    probs_ensenmble = probs
                else:
                    probs_ensenmble += probs
            probs_ensenmble /= len(models)

            best_hyps = models[0].decode_from_probs(
                probs_ensenmble, x_lens, beam_width=beam_width)
        else:
            model = models[0]
            # TODO: fix this
            best_hyps, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                length_penalty=length_penalty,
                task_index=0 if model.model_type in ['ctc', 'attention'] else 1)

        # Decoding may permute the batch; realign the references
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2char(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2char(best_hyps[b])
            str_hyp = re.sub(r'(.*)>(.*)', r'\1', str_hyp)
            # NOTE: Trancate by the first <EOS>

            ##############################
            # Post-proccessing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)
            # NOTE: @ means noise

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            try:
                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub_word += sub_b
                ins_word += ins_b
                del_word += del_b
                num_words += len(str_ref.split('_'))

                # Compute CER
                cer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=list(str_ref.replace('_', '')),
                    hyp=list(str_hyp.replace('_', '')),
                    normalize=False)
                cer += cer_b
                sub_char += sub_b
                ins_char += ins_b
                del_char += del_b
                num_chars += len(str_ref.replace('_', ''))
            except:
                # NOTE(review): silently skips utterances where compute_wer
                # fails; their tokens are excluded from the denominators too.
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    wer /= num_words
    sub_word /= num_words
    ins_word /= num_words
    del_word /= num_words
    cer /= num_chars
    sub_char /= num_chars
    ins_char /= num_chars
    del_char /= num_chars

    df_wer_cer = pd.DataFrame(
        {'SUB': [sub_word * 100, sub_char * 100],
         'INS': [ins_word * 100, ins_char * 100],
         'DEL': [del_word * 100, del_char * 100]},
        columns=['SUB', 'INS', 'DEL'],
        index=['WER', 'CER'])

    return wer, cer, df_wer_cer
def check(self, usage_dec_sub='all', att_reg_weight=1, main_loss_weight=0.5, ctc_loss_weight_sub=0, dec_attend_temperature=1, dec_sigmoid_smoothing=False, backward_sub=False, num_heads=1, second_pass=False, relax_context_vec_dec=False):
    """Smoke-test NestedAttentionSeq2seq: build a small model on synthetic
    word/char data and train it until CER converges or max_step is reached.

    Args:
        usage_dec_sub (string): how the main decoder uses the sub decoder
        att_reg_weight (float): weight of the attention regularization term
        main_loss_weight (float): NOTE(review): accepted but not forwarded
            (main_loss_weight=0.8 is hard-coded in the model kwargs below)
        ctc_loss_weight_sub (float): weight of the sub-task CTC loss
        dec_attend_temperature (float): temperature of the decoder attention
        dec_sigmoid_smoothing (bool): if True, apply sigmoid smoothing to
            the decoder attention weights
        backward_sub (bool): if True, decode the sub task backwards
        num_heads (int): number of attention heads
        second_pass (bool): if True, train the second-pass variant
            (single word-level loss, char vocab replaced by the word vocab)
        relax_context_vec_dec (bool): if True, relax the decoder context vector
    """
    print('==================================================')
    print(' usage_dec_sub: %s' % usage_dec_sub)
    print(' att_reg_weight: %s' % str(att_reg_weight))
    print(' main_loss_weight: %s' % str(main_loss_weight))
    print(' ctc_loss_weight_sub: %s' % str(ctc_loss_weight_sub))
    print(' dec_attend_temperature: %s' % str(dec_attend_temperature))
    print(' dec_sigmoid_smoothing: %s' % str(dec_sigmoid_smoothing))
    print(' backward_sub: %s' % str(backward_sub))
    print(' num_heads: %s' % str(num_heads))
    print(' second_pass: %s' % str(second_pass))
    print(' relax_context_vec_dec: %s' % str(relax_context_vec_dec))
    print('==================================================')

    # Load batch data
    splice = 1
    num_stack = 1
    xs, ys, ys_sub, x_lens, y_lens, y_lens_sub = generate_data(
        label_type='word_char',
        batch_size=2,
        num_stack=num_stack,
        splice=splice)

    # Load model
    model = NestedAttentionSeq2seq(
        input_size=xs.shape[-1] // splice // num_stack,  # 120
        encoder_type='lstm',
        encoder_bidirectional=True,
        encoder_num_units=256,
        encoder_num_proj=0,
        encoder_num_layers=2,
        encoder_num_layers_sub=2,
        attention_type='location',
        attention_dim=128,
        decoder_type='lstm',
        decoder_num_units=256,
        decoder_num_layers=1,
        decoder_num_units_sub=256,
        decoder_num_layers_sub=1,
        embedding_dim=64,
        embedding_dim_sub=32,
        dropout_input=0.1,
        dropout_encoder=0.1,
        dropout_decoder=0.1,
        dropout_embedding=0.1,
        main_loss_weight=0.8,
        sub_loss_weight=0.2 if ctc_loss_weight_sub == 0 else 0,
        num_classes=11,
        num_classes_sub=27 if not second_pass else 11,
        parameter_init_distribution='uniform',
        parameter_init=0.1,
        recurrent_weight_orthogonal=False,
        init_forget_gate_bias_with_one=True,
        subsample_list=[True, False],
        subsample_type='drop',
        init_dec_state='first',
        sharpening_factor=1,
        logits_temperature=1,
        sigmoid_smoothing=False,
        ctc_loss_weight_sub=ctc_loss_weight_sub,
        attention_conv_num_channels=10,
        attention_conv_width=201,
        num_stack=num_stack,
        splice=1,
        conv_channels=[],
        conv_kernel_sizes=[],
        conv_strides=[],
        poolings=[],
        batch_norm=False,
        scheduled_sampling_prob=0.1,
        scheduled_sampling_max_step=200,
        label_smoothing_prob=0.1,
        weight_noise_std=0,
        encoder_residual=False,
        encoder_dense_residual=False,
        decoder_residual=False,
        decoder_dense_residual=False,
        decoding_order='attend_generate_update',
        # decoding_order='attend_update_generate',
        # decoding_order='conditional',
        bottleneck_dim=256,
        bottleneck_dim_sub=256,
        backward_sub=backward_sub,
        num_heads=num_heads,
        num_heads_sub=num_heads,
        num_heads_dec=num_heads,
        usage_dec_sub=usage_dec_sub,
        att_reg_weight=att_reg_weight,
        dec_attend_temperature=dec_attend_temperature,
        # FIX: was dec_sigmoid_smoothing=dec_attend_temperature, which passed
        # the attention temperature where the smoothing flag belongs
        dec_sigmoid_smoothing=dec_sigmoid_smoothing,
        relax_context_vec_dec=relax_context_vec_dec,
        dec_attention_type='location')

    # Count total parameters
    for name in sorted(list(model.num_params_dict.keys())):
        num_params = model.num_params_dict[name]
        print("%s %d" % (name, num_params))
    print("Total %.3f M parameters" % (model.total_parameters / 1000000))

    # Define optimizer
    learning_rate = 1e-3
    model.set_optimizer('adam',
                        learning_rate_init=learning_rate,
                        weight_decay=1e-6,
                        lr_schedule=False,
                        factor=0.1,
                        patience_epoch=5)

    # Define learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               backend='pytorch',
                               decay_start_epoch=20,
                               decay_rate=0.9,
                               decay_patient_epoch=10,
                               lower_better=True)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # Train model
    max_step = 300
    start_time_step = time.time()
    for step in range(max_step):
        # Step for parameter update
        model.optimizer.zero_grad()
        if second_pass:
            loss = model(xs, ys, x_lens, y_lens)
        else:
            loss, loss_main, loss_sub = model(
                xs, ys, x_lens, y_lens, ys_sub, y_lens_sub)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        model.optimizer.step()

        if (step + 1) % 10 == 0:
            # Compute loss in evaluation mode
            if second_pass:
                loss = model(xs, ys, x_lens, y_lens, is_eval=True)
            else:
                loss, loss_main, loss_sub = model(
                    xs, ys, x_lens, y_lens, ys_sub, y_lens_sub, is_eval=True)

            # Decode both tasks
            best_hyps, _, best_hyps_sub, _, perm_idx = model.decode(
                xs, x_lens, beam_width=1,
                max_decode_len=30,
                max_decode_len_sub=60)

            str_hyp = idx2word(best_hyps[0][:-1])
            str_ref = idx2word(ys[0])
            if second_pass:
                str_hyp_sub = idx2word(best_hyps_sub[0][:-1])
                str_ref_sub = idx2word(ys[0])
            else:
                str_hyp_sub = idx2char(best_hyps_sub[0][:-1])
                str_ref_sub = idx2char(ys_sub[0])

            # Compute accuracy; fall back to the worst score if the
            # hypothesis is degenerate and alignment fails
            try:
                wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                           hyp=str_hyp.split('_'),
                                           normalize=True)
                if second_pass:
                    cer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=str_hyp_sub.split('_'),
                                               normalize=True)
                else:
                    cer, _, _, _ = compute_wer(
                        ref=list(str_ref_sub.replace('_', '')),
                        hyp=list(str_hyp_sub.replace('_', '')),
                        normalize=True)
            except Exception:
                wer = 1
                cer = 1

            duration_step = time.time() - start_time_step
            if second_pass:
                print('Step %d: loss=%.3f / wer=%.3f / cer=%.3f / lr=%.5f (%.3f sec)' %
                      (step + 1, loss, wer, cer, learning_rate, duration_step))
            else:
                print('Step %d: loss=%.3f(%.3f/%.3f) / wer=%.3f / cer=%.3f / lr=%.5f (%.3f sec)' %
                      (step + 1, loss, loss_main, loss_sub,
                       wer, cer, learning_rate, duration_step))
            start_time_step = time.time()

            # Visualize
            print('Ref: %s' % str_ref)
            print('Hyp (word): %s' % str_hyp)
            print('Hyp (char): %s' % str_hyp_sub)

            if cer < 0.1:
                print('Model is converged.')  # FIX: typo "Modle is Converged."
                break

            # Update learning rate
            model.optimizer, learning_rate = lr_controller.decay_lr(
                optimizer=model.optimizer,
                learning_rate=learning_rate,
                epoch=step,
                value=wer)
def decode(model, dataset, eval_batch_size, beam_width, length_penalty,
           save_path=None):
    """Visualize phone-level label outputs (prints Ref/Hyp/PER per utterance).

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
            (NOTE(review): accepted but never used in this function)
        beam_width (int): the size of beam
        length_penalty (float): penalty applied per output token
        save_path (string): path to save decoding results; if not None,
            stdout is redirected to <model_dir>/decode.txt
    """
    idx2phone = Idx2phone(dataset.vocab_file_path)

    if save_path is not None:
        # All prints below go into the decode log file
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:
        # Decode
        best_hyps, _, perm_idx = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=MAX_DECODE_LEN_PHONE,
            length_penalty=length_penalty)

        if model.model_type == 'attention' and model.ctc_loss_weight > 0:
            # Also decode with the auxiliary CTC head for comparison
            best_hyps_ctc, perm_idx = model.decode_ctc(batch['xs'],
                                                       batch['x_lens'],
                                                       beam_width=beam_width)

        # Decoding may permute the batch; realign the references
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space(' ')
            else:
                # Convert from list of index to string
                str_ref = idx2phone(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = idx2phone(best_hyps[b])

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref : %s' % str_ref)
            print('Hyp : %s' % str_hyp)
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                str_hyp_ctc = idx2phone(best_hyps_ctc[b])
                print('Hyp (CTC): %s' % str_hyp_ctc)

            # Compute PER (hypothesis truncated at ' >', i.e. the EOS token)
            per, _, _, _ = compute_wer(ref=str_ref.split(' '),
                                       hyp=re.sub(r'(.*) >(.*)', r'\1',
                                                  str_hyp).split(' '),
                                       normalize=True)
            print('PER: %.3f %%' % (per * 100))
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                per_ctc, _, _, _ = compute_wer(ref=str_ref.split(' '),
                                               hyp=str_hyp_ctc.split(' '),
                                               normalize=True)
                print('PER (CTC): %.3f %%' % (per_ctc * 100))

        if is_new_epoch:
            break
def do_eval_wer(session, decode_ops, model, dataset, train_data_size,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Word Error Rate.

    Args:
        session: session of training model
        decode_ops: list of operations for decoding (one per device)
        model: the model to evaluate
        dataset: An instance of `Dataset` class
        train_data_size (string): train100h or train460h or train960h
            (selects the word mapping file)
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        wer_mean (float): An average of WER
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(
        map_file_path='../metrics/mapping_files/word_' + train_data_size + '.txt')

    wer_mean = 0
    skip_data_num = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, labels_true, _, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        feed_dict = {}
        for i_device in range(len(decode_ops)):
            feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
            feed_dict[model.inputs_seq_len_pl_list[i_device]
                      ] = inputs_seq_len[i_device]
            # Disable dropout at evaluation time
            feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

        labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
        for i_device, labels_pred_st in enumerate(labels_pred_st_list):
            batch_size_device = len(inputs[i_device])
            try:
                labels_pred = sparsetensor2list(labels_pred_st,
                                                batch_size_device)

                for i_batch in range(batch_size_device):
                    if is_test:
                        str_true = labels_true[i_device][i_batch][0]
                        # NOTE: transcript is seperated by space('_')
                    else:
                        str_true = '_'.join(
                            idx2word(labels_true[i_device][i_batch]))
                    str_pred = '_'.join(idx2word(labels_pred[i_batch]))

                    # Compute WER
                    # NOTE(review): compute_wer's return is accumulated as a
                    # scalar here, unlike the 4-tuple versions elsewhere in
                    # this file — confirm which compute_wer is imported.
                    wer_mean += compute_wer(ref=str_true.split('_'),
                                            hyp=str_pred.split('_'),
                                            normalize=True)

                    if progressbar:
                        pbar.update(1)
            except IndexError:
                # NOTE(review): presumably sparsetensor2list fails when the
                # decoder emits fewer sequences than the batch size
                print('skipped')
                skip_data_num += batch_size_device
                # TODO: Conduct decoding again with batch size 1
                if progressbar:
                    pbar.update(batch_size_device)

        if is_new_epoch:
            break

    # Average over utterances that were actually scored
    wer_mean /= (len(dataset) - skip_data_num)

    # Register original batch size
    if eval_batch_size is not None:
        dataset.batch_size = batch_size_original

    return wer_mean
def do_eval_cer(session, decode_ops, model, dataset, label_type,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Character Error Rate.

    Args:
        session: session of training model
        decode_ops: list of operations for decoding (one per device)
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True,
            space_mark='_')
    else:
        raise TypeError

    cer_mean, wer_mean = 0, 0
    skip_data_num = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, _, labels_true, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        feed_dict = {}
        for i_device in range(len(decode_ops)):
            feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
            feed_dict[model.inputs_seq_len_pl_list[i_device]
                      ] = inputs_seq_len[i_device]
            # Disable dropout at evaluation time
            feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

        labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
        for i_device, labels_pred_st in enumerate(labels_pred_st_list):
            batch_size_device = len(inputs[i_device])
            try:
                labels_pred = sparsetensor2list(labels_pred_st,
                                                batch_size_device)

                for i_batch in range(batch_size_device):
                    # Convert from list of index to string
                    if is_test:
                        str_true = labels_true[i_device][i_batch][0]
                        # NOTE: transcript is seperated by space('_')
                    else:
                        str_true = idx2char(labels_true[i_device][i_batch],
                                            padded_value=dataset.padded_value)
                    str_pred = idx2char(labels_pred[i_batch])

                    # Remove consecutive spaces
                    str_pred = re.sub(r'[_]+', '_', str_pred)

                    # Remove garbage labels
                    str_true = re.sub(r'[\']+', '', str_true)
                    str_pred = re.sub(r'[\']+', '', str_pred)

                    # Compute WER (on '_'-separated words)
                    wer_mean += compute_wer(ref=str_true.split('_'),
                                            hyp=str_pred.split('_'),
                                            normalize=True)

                    # Remove spaces
                    str_true = re.sub(r'[_]+', '', str_true)
                    str_pred = re.sub(r'[_]+', '', str_pred)

                    # Compute CER (on the space-stripped strings)
                    cer_mean += compute_cer(str_pred=str_pred,
                                            str_true=str_true,
                                            normalize=True)

                    if progressbar:
                        pbar.update(1)
            except IndexError:
                # NOTE(review): presumably sparsetensor2list fails when the
                # decoder emits fewer sequences than the batch size
                print('skipped')
                skip_data_num += batch_size_device
                # TODO: Conduct decoding again with batch size 1
                if progressbar:
                    pbar.update(batch_size_device)

        if is_new_epoch:
            break

    cer_mean /= (len(dataset) - skip_data_num)
    wer_mean /= (len(dataset) - skip_data_num)
    # TODO: Fix this

    # Register original batch size
    if eval_batch_size is not None:
        dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
def do_eval_cer2(session, posteriors_ops, beam_width, model, dataset,
                 label_type, is_test=False, eval_batch_size=None,
                 progressbar=False, is_multitask=False):
    """Evaluate trained model by Character Error Rate, decoding the raw
    frame posteriors with a python-side beam search.

    Args:
        session: session of training model
        posteriors_ops: list of operations for computing posteriors
            (one per device)
        beam_width (int): the size of beam
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    """
    assert isinstance(posteriors_ops, list), "posteriors_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
        char2idx = Char2idx(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        raise NotImplementedError
    else:
        raise TypeError

    # Define decoder (blank is the last class by convention)
    decoder = BeamSearchDecoder(space_index=char2idx('_')[0],
                                blank_index=model.num_classes - 1)

    cer_mean, wer_mean = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, _, labels_true, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        feed_dict = {}
        for i_device in range(len(posteriors_ops)):
            feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
            feed_dict[model.inputs_seq_len_pl_list[i_device]
                      ] = inputs_seq_len[i_device]
            # Disable dropout at evaluation time
            feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

        posteriors_list = session.run(posteriors_ops, feed_dict=feed_dict)
        for i_device, labels_pred_st in enumerate(posteriors_list):
            # NOTE(review): `labels_pred_st` is unused; the list is
            # re-indexed by `i_device` below.
            batch_size_device, max_time = inputs[i_device].shape[:2]

            # Flattened posteriors -> (batch, time, num_classes)
            posteriors = posteriors_list[i_device].reshape(
                batch_size_device, max_time, model.num_classes)

            for i_batch in range(batch_size_device):
                # Decode per utterance
                labels_pred, scores = decoder(
                    probs=posteriors[i_batch:i_batch + 1],
                    seq_len=inputs_seq_len[i_device][i_batch: i_batch + 1],
                    beam_width=beam_width)

                # Convert from list of index to string
                if is_test:
                    str_true = labels_true[i_device][i_batch][0]
                    # NOTE: transcript is seperated by space('_')
                else:
                    str_true = idx2char(labels_true[i_device][i_batch],
                                        padded_value=dataset.padded_value)
                str_pred = idx2char(labels_pred[0])

                # Remove consecutive spaces
                str_pred = re.sub(r'[_]+', '_', str_pred)

                # Remove garbage labels
                str_true = re.sub(r'[\']+', '', str_true)
                str_pred = re.sub(r'[\']+', '', str_pred)

                # Compute WER (on '_'-separated words)
                wer_mean += compute_wer(ref=str_true.split('_'),
                                        hyp=str_pred.split('_'),
                                        normalize=True)

                # Remove spaces
                str_true = re.sub(r'[_]+', '', str_true)
                str_pred = re.sub(r'[_]+', '', str_pred)

                # Compute CER (on the space-stripped strings)
                cer_mean += compute_cer(str_pred=str_pred,
                                        str_true=str_true,
                                        normalize=True)

                if progressbar:
                    pbar.update(1)

        if is_new_epoch:
            break

    cer_mean /= (len(dataset))
    wer_mean /= (len(dataset))

    # Register original batch size
    if eval_batch_size is not None:
        dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
def do_eval_cer(session, decode_op, model, dataset, label_type,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False, is_jointctcatt=False):
    """Evaluate trained model by Character Error Rate.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
        is_jointctcatt (bool, optional): if True, evaluate the joint
            CTC-Attention model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    """
    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2char = Idx2char(
        map_file_path='../metrics/mapping_files/' + label_type + '.txt')

    cer_mean, wer_mean = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini-batch
        # NOTE(review): the multitask and joint CTC-Attention tuples have
        # the same layout here, so both branches unpack identically
        if is_multitask:
            inputs, labels_true, _, inputs_seq_len, labels_seq_len, _ = data
        elif is_jointctcatt:
            inputs, labels_true, _, inputs_seq_len, labels_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, labels_seq_len, _ = data

        # Disable all dropout at evaluation time
        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_encoder_pl_list[0]: 1.0,
            model.keep_prob_decoder_pl_list[0]: 1.0,
            model.keep_prob_embedding_pl_list[0]: 1.0
        }
        batch_size = inputs[0].shape[0]

        labels_pred = session.run(decode_op, feed_dict=feed_dict)
        for i_batch in range(batch_size):
            # Convert from list of index to string
            if is_test:
                str_true = labels_true[0][i_batch][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                # (the [1:len-1] slice presumably strips <SOS>/<EOS> — confirm)
                str_true = idx2char(
                    labels_true[0][i_batch][1:labels_seq_len[0][i_batch] - 1],
                    padded_value=dataset.padded_value)
            str_pred = idx2char(labels_pred[i_batch]).split('>')[0]
            # NOTE: Trancate by <EOS>

            # Remove consecutive spaces
            str_pred = re.sub(r'[_]+', '_', str_pred)

            # Remove garbage labels
            str_true = re.sub(r'[<>\'\":;!?,.-]+', '', str_true)
            str_pred = re.sub(r'[<>\'\":;!?,.-]+', '', str_pred)

            # Compute WER (on '_'-separated words)
            wer_mean += compute_wer(hyp=str_pred.split('_'),
                                    ref=str_true.split('_'),
                                    normalize=True)

            # Remove spaces
            str_pred = re.sub(r'[_]+', '', str_pred)
            str_true = re.sub(r'[_]+', '', str_true)

            # Compute CER (on the space-stripped strings)
            cer_mean += compute_cer(str_pred=str_pred,
                                    str_true=str_true,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    cer_mean /= len(dataset)
    wer_mean /= len(dataset)

    # Register original batch size
    if eval_batch_size is not None:
        dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
def decode(model, dataset, beam_width, max_decode_len, max_decode_len_sub,
           eval_batch_size=None, save_path=None):
    """Visualize label outputs of a hierarchical (word/char) model.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        max_decode_len_sub (int): same, for the sub (character) task
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string): path to save decoding results; if not None,
            stdout is redirected to <model_dir>/decode.txt
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)
    if dataset.label_type_sub == 'character':
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub)
    elif dataset.label_type_sub == 'character_capital_divide':
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub,
                            capital_divide=True)

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:
        # Decode
        if model.model_type == 'charseq_attention':
            best_hyps, best_hyps_sub, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                max_decode_len_sub=100)
        else:
            best_hyps, perm_idx = model.decode(batch['xs'], batch['x_lens'],
                                               beam_width=beam_width,
                                               max_decode_len=max_decode_len)
            # NOTE(review): the second decode overwrites perm_idx; this
            # assumes both calls permute the batch identically — confirm
            best_hyps_sub, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len_sub,
                is_sub_task=True)

        # Realign references with the (possibly permuted) decode order
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                str_ref_sub = ys_sub[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2word(ys[b][:y_lens[b]])
                # FIX: was idx2word — the sub-task labels are character
                # indices, and the CER below compares this string against
                # the idx2char-decoded sub hypothesis
                str_ref_sub = idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = idx2word(best_hyps[b])
            str_hyp_sub = idx2char(best_hyps_sub[b])

            if model.model_type != 'hierarchical_ctc':
                str_hyp = str_hyp.split('>')[0]
                str_hyp_sub = str_hyp_sub.split('>')[0]
                # NOTE: Trancate by the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]
                if len(str_hyp_sub) > 0 and str_hyp_sub[-1] == '_':
                    str_hyp_sub = str_hyp_sub[:-1]

            ##############################
            # Post-proccessing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[\'>]+', '', str_ref)
            str_hyp = re.sub(r'[\'>]+', '', str_hyp)

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref: %s' % str_ref.replace('_', ' '))
            print('Hyp (main): %s' % str_hyp.replace('_', ' '))
            # print('Ref (sub): %s' % str_ref_sub.replace('_', ' '))
            print('Hyp (sub): %s' % str_hyp_sub.replace('_', ' '))

            wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                       hyp=str_hyp.split('_'),
                                       normalize=True)
            print('WER: %f %%' % (wer * 100))
            cer, _, _, _ = compute_wer(ref=list(str_ref_sub.replace('_', '')),
                                       hyp=list(str_hyp_sub.replace('_', '')),
                                       normalize=True)
            print('CER: %f %%' % (cer * 100))

        if is_new_epoch:
            break
def do_eval_cer2(session, posteriors_ops, beam_width, model, dataset,
                 label_type, is_test=False, eval_batch_size=None,
                 progressbar=False, is_multitask=False):
    """Evaluate trained model by Character Error Rate, decoding the raw
    frame posteriors with a python-side beam search.

    Args:
        session: session of training model
        posteriors_ops: list of operations for computing posteriors
            (one per device)
        beam_width (int): the size of beam
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    """
    assert isinstance(posteriors_ops, list), "posteriors_ops must be a list."

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
        char2idx = Char2idx(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        raise NotImplementedError
    else:
        raise TypeError

    # Define decoder (blank is the last class by convention)
    decoder = BeamSearchDecoder(space_index=char2idx('_')[0],
                                blank_index=model.num_classes - 1)

    cer_mean, wer_mean = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, _, labels_true, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        feed_dict = {}
        for i_device in range(len(posteriors_ops)):
            feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
            feed_dict[model.inputs_seq_len_pl_list[i_device]] = inputs_seq_len[
                i_device]
            # Disable dropout at evaluation time
            feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

        posteriors_list = session.run(posteriors_ops, feed_dict=feed_dict)
        for i_device, labels_pred_st in enumerate(posteriors_list):
            # NOTE(review): `labels_pred_st` is unused; the list is
            # re-indexed by `i_device` below.
            batch_size_device, max_time = inputs[i_device].shape[:2]

            # Flattened posteriors -> (batch, time, num_classes)
            posteriors = posteriors_list[i_device].reshape(
                batch_size_device, max_time, model.num_classes)

            for i_batch in range(batch_size_device):
                # Decode per utterance
                labels_pred, scores = decoder(
                    probs=posteriors[i_batch:i_batch + 1],
                    seq_len=inputs_seq_len[i_device][i_batch:i_batch + 1],
                    beam_width=beam_width)

                # Convert from list of index to string
                if is_test:
                    str_true = labels_true[i_device][i_batch][0]
                    # NOTE: transcript is seperated by space('_')
                else:
                    str_true = idx2char(labels_true[i_device][i_batch],
                                        padded_value=dataset.padded_value)
                str_pred = idx2char(labels_pred[0])

                # Remove consecutive spaces
                str_pred = re.sub(r'[_]+', '_', str_pred)

                # Remove garbage labels
                str_true = re.sub(r'[\']+', '', str_true)
                str_pred = re.sub(r'[\']+', '', str_pred)

                # Compute WER (on '_'-separated words)
                wer_mean += compute_wer(ref=str_true.split('_'),
                                        hyp=str_pred.split('_'),
                                        normalize=True)

                # Remove spaces
                str_true = re.sub(r'[_]+', '', str_true)
                str_pred = re.sub(r'[_]+', '', str_pred)

                # Compute CER (on the space-stripped strings)
                cer_mean += compute_cer(str_pred=str_pred,
                                        str_true=str_true,
                                        normalize=True)

                if progressbar:
                    pbar.update(1)

        if is_new_epoch:
            break

    cer_mean /= (len(dataset))
    wer_mean /= (len(dataset))

    # Register original batch size
    if eval_batch_size is not None:
        dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
def main():
    """Decode the phone-level model on the test set and print per-utterance
    hypotheses together with PER (phone error rate).

    Reads CLI arguments from the module-level `parser`; relies on
    module-level constants MAX_DECODE_LEN_PHONE / MIN_DECODE_LEN_PHONE.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset
    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='test',
        label_type=params['label_type'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=True,
        reverse=True,
        tool=params['tool'])
    params['num_classes'] = dataset.num_classes

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    ######################################################################

    for batch, is_new_epoch in dataset:
        # Decode
        best_hyps, _, perm_idx = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=args.beam_width,
            max_decode_len=MAX_DECODE_LEN_PHONE,
            min_decode_len=MIN_DECODE_LEN_PHONE,
            length_penalty=args.length_penalty,
            coverage_penalty=args.coverage_penalty)

        if model.model_type == 'attention' and model.ctc_loss_weight > 0:
            best_hyps_ctc, perm_idx = model.decode_ctc(
                batch['xs'], batch['x_lens'], beam_width=args.beam_width)

        # Decoding may permute the batch; reorder references to match.
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space(' ')
            else:
                # Convert from list of index to string
                str_ref = dataset.idx2phone(ys[b][: y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = dataset.idx2phone(best_hyps[b])

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref : %s' % str_ref)
            print('Hyp : %s' % str_hyp)
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                str_hyp_ctc = dataset.idx2phone(best_hyps_ctc[b])
                print('Hyp (CTC): %s' % str_hyp_ctc)

            # Compute PER
            # The re.sub truncates everything after the first '>' (EOS).
            per, _, _, _ = compute_wer(ref=str_ref.split(' '),
                                       hyp=re.sub(r'(.*) >(.*)', r'\1',
                                                  str_hyp).split(' '),
                                       normalize=True)
            print('PER: %.3f %%' % (per * 100))
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                per_ctc, _, _, _ = compute_wer(ref=str_ref.split(' '),
                                               hyp=str_hyp_ctc.split(' '),
                                               normalize=True)
                print('PER (CTC): %.3f %%' % (per_ctc * 100))

        if is_new_epoch:
            break
def do_eval_cer(session, decode_ops, model, dataset, label_type,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Character Error Rate.
    Args:
        session: session of training model
        decode_ops: list of operations for decoding (one op per device)
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."
    # Remember the batch size so it can be restored after evaluation.
    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True,
            space_mark='_')
    else:
        raise TypeError

    cer_mean, wer_mean = 0, 0
    skip_data_num = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, _, labels_true, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        feed_dict = {}
        for i_device in range(len(decode_ops)):
            feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
            feed_dict[model.inputs_seq_len_pl_list[i_device]] = inputs_seq_len[
                i_device]
            # Disable dropout at evaluation time
            feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

        labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
        for i_device, labels_pred_st in enumerate(labels_pred_st_list):
            batch_size_device = len(inputs[i_device])
            try:
                labels_pred = sparsetensor2list(labels_pred_st,
                                                batch_size_device)
                for i_batch in range(batch_size_device):

                    # Convert from list of index to string
                    if is_test:
                        str_true = labels_true[i_device][i_batch][0]
                        # NOTE: transcript is seperated by space('_')
                    else:
                        str_true = idx2char(labels_true[i_device][i_batch],
                                            padded_value=dataset.padded_value)
                    str_pred = idx2char(labels_pred[i_batch])

                    # Remove consecutive spaces
                    str_pred = re.sub(r'[_]+', '_', str_pred)

                    # Remove garbage labels
                    str_true = re.sub(r'[\']+', '', str_true)
                    str_pred = re.sub(r'[\']+', '', str_pred)

                    # Compute WER
                    wer_mean += compute_wer(ref=str_true.split('_'),
                                            hyp=str_pred.split('_'),
                                            normalize=True)
                    # substitute, insert, delete = wer_align(
                    #     ref=str_pred.split('_'),
                    #     hyp=str_true.split('_'))
                    # print('SUB: %d' % substitute)
                    # print('INS: %d' % insert)
                    # print('DEL: %d' % delete)

                    # Remove spaces
                    str_true = re.sub(r'[_]+', '', str_true)
                    str_pred = re.sub(r'[_]+', '', str_pred)

                    # Compute CER
                    cer_mean += compute_cer(str_pred=str_pred,
                                            str_true=str_true,
                                            normalize=True)

                    if progressbar:
                        pbar.update(1)
            except IndexError:
                # sparsetensor2list fails when some utterance in the device
                # batch decoded to an empty label sequence; skip the whole
                # device batch and exclude it from the average below.
                print('skipped')
                skip_data_num += batch_size_device
                # TODO: Conduct decoding again with batch size 1
                if progressbar:
                    pbar.update(batch_size_device)

        if is_new_epoch:
            break

    cer_mean /= (len(dataset) - skip_data_num)
    wer_mean /= (len(dataset) - skip_data_num)
    # TODO: Fix this

    # Register original batch size
    if eval_batch_size is not None:
        dataset.batch_size = batch_size_original

    return cer_mean, wer_mean
def decode(model, dataset, beam_width, beam_width_sub,
           eval_batch_size=None, save_path=None, resolving_unk=False):
    """Visualize label outputs of the main (word) and sub (character) tasks.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam in the main task
        beam_width_sub (int): the size of beam in the sub task
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string): path to save decoding results
        resolving_unk (bool, optional): if True, resolve OOV words in the
            main-task hypothesis with the character-level sub task
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)
    idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub,
                        capital_divide=dataset.label_type_sub == 'character_capital_divide')

    # Read GLM file
    glm = GLM(
        glm_path='/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:
        # Decode
        if model.model_type == 'nested_attention':
            best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                beam_width_sub=beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_WORD,
                max_decode_len_sub=MAX_DECODE_LEN_CHAR)
        else:
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD)
            best_hyps_sub, aw_sub, _ = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_CHAR,
                task_index=1)

        # Decoding may permute the batch; reorder references to match.
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref_original = ys[b][0]
                str_ref_sub = ys_sub[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref_original = idx2word(ys[b][: y_lens[b]])
                # BUG FIX: the sub task is character-level, so convert the
                # sub reference with idx2char (was idx2word, which mapped
                # character indices through the word vocabulary).
                str_ref_sub = idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = idx2word(best_hyps[b])
            str_hyp_sub = idx2char(best_hyps_sub[b])

            ##############################
            # Resolving UNK
            ##############################
            # Decide on the raw hypothesis, before post-processing can
            # rewrite the 'OOV' marker.
            has_oov = resolving_unk and 'OOV' in str_hyp
            str_hyp_no_unk = None
            if has_oov:
                str_hyp_no_unk = resolve_unk(
                    str_hyp, best_hyps_sub[b], aw[b], aw_sub[b], idx2char)
            # if 'OOV' not in str_hyp:
            #     continue

            ##############################
            # Post-proccessing
            ##############################
            str_ref = fix_trans(str_ref_original, glm)
            str_ref_sub = fix_trans(str_ref_sub, glm)
            str_hyp = fix_trans(str_hyp, glm)
            str_hyp_sub = fix_trans(str_hyp_sub, glm)
            # BUG FIX: str_hyp_no_unk is only bound when an OOV was
            # resolved; the original called fix_trans unconditionally and
            # raised UnboundLocalError on every OOV-free utterance.
            if has_oov:
                str_hyp_no_unk = fix_trans(str_hyp_no_unk, glm)

            if len(str_ref) == 0:
                continue

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref : %s' % str_ref.replace('_', ' '))
            print('Hyp (main) : %s' % str_hyp.replace('_', ' '))
            print('Hyp (sub) : %s' % str_hyp_sub.replace('_', ' '))
            if has_oov:
                print('Hyp (no UNK): %s' % str_hyp_no_unk.replace('_', ' '))

            try:
                # Compute WER
                # BUG FIX: str.replace() treats its argument literally, so
                # replace(r'_>.*', '') was a no-op; use re.sub to truncate
                # everything from the first <EOS> marker.
                wer, _, _, _ = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=re.sub(r'_>.*', '', str_hyp).split('_'),
                    normalize=True)
                print('WER (main) : %.3f %%' % (wer * 100))
                wer_sub, _, _, _ = compute_wer(
                    ref=str_ref_sub.split('_'),
                    hyp=re.sub(r'>.*', '', str_hyp_sub).split('_'),
                    normalize=True)
                print('WER (sub) : %.3f %%' % (wer_sub * 100))
                if has_oov:
                    wer_no_unk, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'_>.*', '',
                                   str_hyp_no_unk.replace('*', '')).split('_'),
                        normalize=True)
                    print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))
            except Exception:
                # Best-effort: skip utterances whose WER cannot be computed.
                print('--- skipped ---')

        if is_new_epoch:
            break
def do_eval_wer(session, decode_ops, model, dataset, train_data_size,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Word Error Rate.
    Args:
        session: session of training model
        decode_ops: list of operations for decoding (one op per device)
        model: the model to evaluate
        dataset: An instance of `Dataset` class
        train_data_size (string): train100h or train460h or train960h
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        wer_mean (float): An average of WER
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."
    # Remember the batch size so it can be restored after evaluation.
    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # The word mapping file depends on the training-set size.
    idx2word = Idx2word(map_file_path='../metrics/mapping_files/word_' +
                        train_data_size + '.txt')

    wer_mean = 0
    skip_data_num = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, labels_true, _, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        feed_dict = {}
        for i_device in range(len(decode_ops)):
            feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
            feed_dict[model.inputs_seq_len_pl_list[i_device]] = inputs_seq_len[
                i_device]
            # Disable dropout at evaluation time
            feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

        labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
        for i_device, labels_pred_st in enumerate(labels_pred_st_list):
            batch_size_device = len(inputs[i_device])
            try:
                labels_pred = sparsetensor2list(labels_pred_st,
                                                batch_size_device)
                for i_batch in range(batch_size_device):
                    if is_test:
                        str_true = labels_true[i_device][i_batch][0]
                        # NOTE: transcript is seperated by space('_')
                    else:
                        str_true = '_'.join(
                            idx2word(labels_true[i_device][i_batch]))
                    str_pred = '_'.join(idx2word(labels_pred[i_batch]))

                    # if len(str_true.split('_')) == 0:
                    #     print(str_true)
                    #     print(str_pred)

                    # Compute WER
                    wer_mean += compute_wer(ref=str_true.split('_'),
                                            hyp=str_pred.split('_'),
                                            normalize=True)
                    # substitute, insert, delete = wer_align(
                    #     ref=str_true.split(' '),
                    #     hyp=str_pred.split(' '))
                    # print('SUB: %d' % substitute)
                    # print('INS: %d' % insert)
                    # print('DEL: %d' % delete)

                    if progressbar:
                        pbar.update(1)
            except IndexError:
                # sparsetensor2list fails when some utterance in the device
                # batch decoded to an empty label sequence; skip the whole
                # device batch and exclude it from the average below.
                print('skipped')
                skip_data_num += batch_size_device
                # TODO: Conduct decoding again with batch size 1
                if progressbar:
                    pbar.update(batch_size_device)

        if is_new_epoch:
            break

    wer_mean /= (len(dataset) - skip_data_num)

    # Register original batch size
    if eval_batch_size is not None:
        dataset.batch_size = batch_size_original

    return wer_mean
def check(self, encoder_type, decoder_type, bidirectional=False,
          attention_type='location', label_type='char', subsample=False,
          projection=False, init_dec_state='first', ctc_loss_weight=0,
          conv=False, batch_norm=False, residual=False, dense_residual=False,
          decoding_order='bahdanau_attention', backward_loss_weight=0,
          num_heads=1, beam_width=1):
    """Build an AttentionSeq2seq model with the given options and overfit
    it on a tiny synthetic batch to smoke-test training and decoding.

    Parameters mirror the AttentionSeq2seq constructor options under test;
    the loop stops early once the label error rate drops below 0.1.
    """
    print('==================================================')
    print(' label_type: %s' % label_type)
    print(' encoder_type: %s' % encoder_type)
    print(' bidirectional: %s' % str(bidirectional))
    print(' projection: %s' % str(projection))
    print(' decoder_type: %s' % decoder_type)
    print(' init_dec_state: %s' % init_dec_state)
    print(' attention_type: %s' % attention_type)
    print(' subsample: %s' % str(subsample))
    print(' ctc_loss_weight: %s' % str(ctc_loss_weight))
    print(' conv: %s' % str(conv))
    print(' batch_norm: %s' % str(batch_norm))
    print(' residual: %s' % str(residual))
    print(' dense_residual: %s' % str(dense_residual))
    print(' decoding_order: %s' % decoding_order)
    print(' backward_loss_weight: %s' % str(backward_loss_weight))
    print(' num_heads: %s' % str(num_heads))
    print(' beam_width: %s' % str(beam_width))
    print('==================================================')

    if conv or encoder_type == 'cnn':
        # pattern 1
        # conv_channels = [32, 32]
        # conv_kernel_sizes = [[41, 11], [21, 11]]
        # conv_strides = [[2, 2], [2, 1]]
        # poolings = [[], []]

        # pattern 2 (VGG like)
        conv_channels = [64, 64]
        conv_kernel_sizes = [[3, 3], [3, 3]]
        conv_strides = [[1, 1], [1, 1]]
        poolings = [[2, 2], [2, 2]]
    else:
        conv_channels = []
        conv_kernel_sizes = []
        conv_strides = []
        poolings = []

    # Load batch data
    splice = 1
    num_stack = 1 if subsample or conv or encoder_type == 'cnn' else 3
    xs, ys, x_lens, y_lens = generate_data(label_type=label_type,
                                           batch_size=2,
                                           num_stack=num_stack,
                                           splice=splice)

    if label_type == 'char':
        num_classes = 27
        map_fn = idx2char
    elif label_type == 'word':
        num_classes = 11
        map_fn = idx2word

    # Load model
    model = AttentionSeq2seq(
        input_size=xs.shape[-1] // splice // num_stack,  # 120
        encoder_type=encoder_type,
        encoder_bidirectional=bidirectional,
        encoder_num_units=256,
        encoder_num_proj=256 if projection else 0,
        encoder_num_layers=1 if not subsample else 2,
        attention_type=attention_type,
        attention_dim=128,
        decoder_type=decoder_type,
        decoder_num_units=256,
        decoder_num_layers=1,
        embedding_dim=32,
        dropout_input=0.1,
        dropout_encoder=0.1,
        dropout_decoder=0.1,
        dropout_embedding=0.1,
        num_classes=num_classes,
        parameter_init_distribution='uniform',
        parameter_init=0.1,
        recurrent_weight_orthogonal=False,
        init_forget_gate_bias_with_one=True,
        subsample_list=[] if not subsample else [True, False],
        subsample_type='concat' if not subsample else subsample,
        bridge_layer=True,
        init_dec_state=init_dec_state,
        sharpening_factor=1,
        logits_temperature=1,
        sigmoid_smoothing=False,
        coverage_weight=0,
        ctc_loss_weight=ctc_loss_weight,
        attention_conv_num_channels=10,
        attention_conv_width=201,
        num_stack=num_stack,
        splice=splice,
        input_channel=3,
        conv_channels=conv_channels,
        conv_kernel_sizes=conv_kernel_sizes,
        conv_strides=conv_strides,
        poolings=poolings,
        activation='relu',
        batch_norm=batch_norm,
        scheduled_sampling_prob=0.1,
        scheduled_sampling_max_step=200,
        label_smoothing_prob=0.1,
        weight_noise_std=1e-9,
        encoder_residual=residual,
        encoder_dense_residual=dense_residual,
        decoder_residual=residual,
        decoder_dense_residual=dense_residual,
        decoding_order=decoding_order,
        bottleneck_dim=256,
        backward_loss_weight=backward_loss_weight,
        num_heads=num_heads)

    # Count total parameters
    for name in sorted(list(model.num_params_dict.keys())):
        num_params = model.num_params_dict[name]
        print("%s %d" % (name, num_params))
    print("Total %.3f M parameters" % (model.total_parameters / 1000000))

    # Define optimizer
    learning_rate = 1e-3
    model.set_optimizer('adam',
                        learning_rate_init=learning_rate,
                        weight_decay=1e-8,
                        lr_schedule=False,
                        factor=0.1,
                        patience_epoch=5)

    # Define learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               backend='pytorch',
                               decay_start_epoch=20,
                               decay_rate=0.9,
                               decay_patient_epoch=10,
                               lower_better=True)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # Train model
    max_step = 300
    start_time_step = time.time()
    for step in range(max_step):
        # Step for parameter update
        model.optimizer.zero_grad()
        loss = model(xs, ys, x_lens, y_lens)
        loss.backward()
        # NOTE(review): clip_grad_norm and loss.data[0] below are the
        # pre-0.4 PyTorch APIs; the commented lines are the modern forms.
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        torch.nn.utils.clip_grad_norm(model.parameters(), 5)
        model.optimizer.step()

        # Inject Gaussian noise to all parameters
        # if loss.item() < 50:
        if loss.data[0] < 50:
            model.weight_noise_injection = True

        if (step + 1) % 10 == 0:
            # Compute loss
            loss = model(xs, ys, x_lens, y_lens, is_eval=True)

            # Decode
            best_hyps, _, perm_idx = model.decode(xs, x_lens, beam_width,
                                                  max_decode_len=60)

            str_ref = map_fn(ys[0])
            # Drop the trailing token (EOS) from the hypothesis.
            str_hyp = map_fn(best_hyps[0][:-1])

            # Compute accuracy
            try:
                if label_type == 'char':
                    ler, _, _, _ = compute_wer(
                        ref=list(str_ref.replace('_', '')),
                        hyp=list(str_hyp.replace('_', '')),
                        normalize=True)
                elif label_type == 'word':
                    ler, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=str_hyp.split('_'),
                                               normalize=True)
            except:
                ler = 1

            duration_step = time.time() - start_time_step
            print('Step %d: loss=%.3f / ler=%.3f / lr=%.5f (%.3f sec)' %
                  (step + 1, loss, ler, learning_rate, duration_step))
            start_time_step = time.time()

            # Visualize
            print('Ref: %s' % str_ref)
            print('Hyp: %s' % str_hyp)

            # Decode by the CTC decoder
            if model.ctc_loss_weight >= 0.1:
                best_hyps_ctc, perm_idx = model.decode_ctc(xs, x_lens,
                                                           beam_width=1)
                str_pred_ctc = map_fn(best_hyps_ctc[0])
                print('Hyp (CTC): %s' % str_pred_ctc)

            if ler < 0.1:
                print('Modle is Converged.')
                break

            # Update learning rate
            model.optimizer, learning_rate = lr_controller.decay_lr(
                optimizer=model.optimizer,
                learning_rate=learning_rate,
                epoch=step,
                value=ler)
def eval_word(models, dataset, eval_batch_size, beam_width, max_decode_len,
              min_decode_len=0, beam_width_sub=1, max_decode_len_sub=200,
              min_decode_len_sub=0, length_penalty=0, coverage_penalty=0,
              progressbar=False, resolving_unk=False, a2c_oracle=False,
              joint_decoding=None, score_sub_weight=0):
    """Evaluate trained model by Word Error Rate.
    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset' class
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam in ths main task
        max_decode_len (int): the maximum sequence length of tokens in the main task
        min_decode_len (int): the minimum sequence length of tokens in the main task
        beam_width_sub (int): the size of beam in ths sub task
            This is used for the nested attention
        max_decode_len_sub (int): the maximum sequence length of tokens in the sub task
        min_decode_len_sub (int): the minimum sequence length of tokens in the sub task
        length_penalty (float): length penalty in beam search decoding
        coverage_penalty (float): coverage penalty in beam search decoding
        progressbar (bool): if True, visualize the progressbar
        resolving_unk (bool): if True, resolve OOV words in the main-task
            hypothesis with the character-level sub task
        a2c_oracle (bool): if True, feed the ground-truth character sequence
            to the sub task (teacher forcing)
        joint_decoding (bool): onepass or resocring or None
        score_sub_weight (float): weight of the sub-task score in joint decoding
    Returns:
        wer (float): Word error rate
        df_word (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    model = models[0]
    # TODO: fix this

    if model.model_type == 'hierarchical_attention' and joint_decoding is not None:
        word2char = Word2char(dataset.vocab_file_path,
                              dataset.vocab_file_path_sub)

    wer = 0
    sub, ins, dele, = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)
        batch_size = len(batch['xs'])

        # Decode
        if model.model_type == 'nested_attention':
            # NOTE(review): when a2c_oracle is True but the set is not a
            # test set, ys_sub is taken from the batch via the else branch
            # pairing below — confirm against the original layout.
            if a2c_oracle:
                if dataset.is_test:
                    max_label_num = 0
                    for b in range(batch_size):
                        if max_label_num < len(list(batch['ys_sub'][b][0])):
                            max_label_num = len(
                                list(batch['ys_sub'][b][0]))

                    ys_sub = np.zeros(
                        (batch_size, max_label_num), dtype=np.int32)
                    ys_sub -= 1  # pad with -1
                    y_lens_sub = np.zeros((batch_size,), dtype=np.int32)
                    for b in range(batch_size):
                        indices = dataset.char2idx(batch['ys_sub'][b][0])
                        ys_sub[b, :len(indices)] = indices
                        y_lens_sub[b] = len(indices)
                        # NOTE: transcript is seperated by space('_')
            else:
                ys_sub = batch['ys_sub']
                y_lens_sub = batch['y_lens_sub']

            best_hyps, aw, best_hyps_sub, aw_sub, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                beam_width_sub=beam_width_sub,
                max_decode_len_sub=max_label_num if a2c_oracle else max_decode_len_sub,
                min_decode_len_sub=min_decode_len_sub,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty,
                teacher_forcing=a2c_oracle,
                ys_sub=ys_sub,
                y_lens_sub=y_lens_sub)
        elif model.model_type == 'hierarchical_attention' and joint_decoding is not None:
            best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty,
                joint_decoding=joint_decoding,
                space_index=dataset.char2idx('_')[0],
                oov_index=dataset.word2idx('OOV')[0],
                word2char=word2char,
                idx2word=dataset.idx2word,
                idx2char=dataset.idx2char,
                score_sub_weight=score_sub_weight)
        else:
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                min_decode_len=min_decode_len,
                length_penalty=length_penalty,
                coverage_penalty=coverage_penalty)

            if resolving_unk:
                best_hyps_sub, aw_sub, _ = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len_sub,
                    min_decode_len=min_decode_len_sub,
                    length_penalty=length_penalty,
                    coverage_penalty=coverage_penalty,
                    task_index=1)

        # Decoding may permute the batch; reorder references to match.
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(batch_size):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = dataset.idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = dataset.idx2word(best_hyps[b])
            if dataset.label_type == 'word':
                str_hyp = re.sub(r'(.*)_>(.*)', r'\1', str_hyp)
            else:
                str_hyp = re.sub(r'(.*)>(.*)', r'\1', str_hyp)
            # NOTE: Trancate by the first <EOS>

            ##############################
            # Resolving UNK
            ##############################
            if resolving_unk and 'OOV' in str_hyp:
                str_hyp = resolve_unk(
                    str_hyp, best_hyps_sub[b], aw[b], aw_sub[b],
                    dataset.idx2char)
                str_hyp = str_hyp.replace('*', '')

            ##############################
            # Post-proccessing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)
            # NOTE: @ means noise

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            # Compute WER
            try:
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_words += len(str_ref.split('_'))
            except:
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Error counts were accumulated unnormalized; divide by the total
    # reference word count to obtain rates.
    wer /= num_words
    sub /= num_words
    ins /= num_words
    dele /= num_words

    df_word = pd.DataFrame(
        {'SUB': [sub * 100], 'INS': [ins * 100], 'DEL': [dele * 100]},
        columns=['SUB', 'INS', 'DEL'],
        index=['WER'])

    return wer, df_word
def eval_char(models, dataset, beam_width, max_decode_len,
              eval_batch_size=None, length_penalty=0, progressbar=False,
              temperature=1):
    """Evaluate trained model by Character Error Rate.
    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset' class
        beam_width: (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the model
        length_penalty (float, optional): length penalty in beam search decoding
        progressbar (bool, optional): if True, visualize the progressbar
        temperature (int, optional): softmax temperature (currently unused here)
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate
        df_wer_cer (pd.DataFrame): dataframe of substitution, insertion,
            and deletion
    """
    # Reset data counter
    dataset.reset()

    # Single-task models read the main vocabulary; multitask models read
    # the character-level sub vocabulary.
    if models[0].model_type in ['ctc', 'attention']:
        idx2char = Idx2char(
            vocab_file_path=dataset.vocab_file_path,
            capital_divide=(dataset.label_type == 'character_capital_divide'))
    else:
        idx2char = Idx2char(
            vocab_file_path=dataset.vocab_file_path_sub,
            capital_divide=(
                dataset.label_type_sub == 'character_capital_divide'))

    # Read GLM file
    glm = GLM(
        glm_path='/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm')

    wer, cer = 0, 0
    sub_word, ins_word, del_word = 0, 0, 0
    sub_char, ins_char, del_char = 0, 0, 0
    num_words, num_chars = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # TODO: add CTC ensemble

        # Decode
        model = models[0]
        # TODO: fix this
        if model.model_type in ['ctc', 'attention']:
            best_hyps, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                length_penalty=length_penalty)
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]
        else:
            best_hyps, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                length_penalty=length_penalty,
                task_index=1)
            ys = batch['ys_sub'][perm_idx]
            y_lens = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2char(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2char(best_hyps[b])
            if 'attention' in model.model_type:
                str_hyp = str_hyp.split('>')[0]
                # NOTE: Trancate by the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]

            ##############################
            # Post-proccessing
            ##############################
            str_ref = fix_trans(str_ref, glm)
            str_hyp = fix_trans(str_hyp, glm)

            # Skip utterances whose reference becomes empty after GLM fixes.
            if len(str_ref) == 0:
                if progressbar:
                    pbar.update(1)
                continue

            try:
                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub_word += sub_b
                ins_word += ins_b
                del_word += del_b
                num_words += len(str_ref.split('_'))

                # Compute CER
                cer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=list(str_ref.replace('_', '')),
                    hyp=list(str_hyp.replace('_', '')),
                    normalize=False)
                cer += cer_b
                sub_char += sub_b
                ins_char += ins_b
                del_char += del_b
                num_chars += len(str_ref.replace('_', ''))
            except:
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Error counts were accumulated unnormalized; divide by the totals
    # to obtain rates.
    wer /= num_words
    sub_word /= num_words
    ins_word /= num_words
    del_word /= num_words

    cer /= num_chars
    sub_char /= num_chars
    ins_char /= num_chars
    del_char /= num_chars

    df_wer_cer = pd.DataFrame(
        {
            'SUB': [sub_word * 100, sub_char * 100],
            'INS': [ins_word * 100, ins_char * 100],
            'DEL': [del_word * 100, del_char * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER', 'CER'])

    return wer, cer, df_wer_cer
def check(self, encoder_type, bidirectional=False, subsample=False,
          projection=False, conv=False, batch_norm=False, activation='relu',
          encoder_residual=False, encoder_dense_residual=False,
          label_smoothing=False):
    """Build a HierarchicalCTC model with the given options and overfit it
    on a tiny synthetic batch to smoke-test training and decoding.

    Parameters mirror the HierarchicalCTC constructor options under test;
    the loop stops early once the character error rate drops below 0.1.
    """
    print('==================================================')
    print(' encoder_type: %s' % encoder_type)
    print(' bidirectional: %s' % str(bidirectional))
    print(' projection: %s' % str(projection))
    print(' subsample: %s' % str(subsample))
    print(' conv: %s' % str(conv))
    print(' batch_norm: %s' % str(batch_norm))
    print(' encoder_residual: %s' % str(encoder_residual))
    print(' encoder_dense_residual: %s' % str(encoder_dense_residual))
    print(' label_smoothing: %s' % str(label_smoothing))
    print('==================================================')

    if conv or encoder_type == 'cnn':
        # pattern 1
        # conv_channels = [32, 32]
        # conv_kernel_sizes = [[41, 11], [21, 11]]
        # conv_strides = [[2, 2], [2, 1]]
        # poolings = [[], []]

        # pattern 2 (VGG like)
        conv_channels = [64, 64]
        conv_kernel_sizes = [[3, 3], [3, 3]]
        conv_strides = [[1, 1], [1, 1]]
        poolings = [[2, 2], [2, 2]]
        fc_list = [786, 786]
    else:
        conv_channels = []
        conv_kernel_sizes = []
        conv_strides = []
        poolings = []
        fc_list = []

    # Load batch data
    num_stack = 1 if subsample or conv or encoder_type == 'cnn' else 2
    splice = 1
    xs, ys, ys_sub, x_lens, y_lens, y_lens_sub = generate_data(
        label_type='word_char',
        batch_size=2,
        num_stack=num_stack,
        splice=splice)

    num_classes = 11
    num_classes_sub = 27

    # Load model
    model = HierarchicalCTC(
        input_size=xs.shape[-1] // splice // num_stack,  # 120
        encoder_type=encoder_type,
        encoder_bidirectional=bidirectional,
        encoder_num_units=256,
        encoder_num_proj=256 if projection else 0,
        encoder_num_layers=2,
        encoder_num_layers_sub=1,
        fc_list=fc_list,
        fc_list_sub=fc_list,
        dropout_input=0.1,
        dropout_encoder=0.1,
        main_loss_weight=0.8,
        sub_loss_weight=0.2,
        num_classes=num_classes,
        num_classes_sub=num_classes_sub,
        parameter_init_distribution='uniform',
        parameter_init=0.1,
        recurrent_weight_orthogonal=False,
        init_forget_gate_bias_with_one=True,
        subsample_list=[] if not subsample else [True, False],
        num_stack=num_stack,
        splice=splice,
        input_channel=3,
        conv_channels=conv_channels,
        conv_kernel_sizes=conv_kernel_sizes,
        conv_strides=conv_strides,
        poolings=poolings,
        batch_norm=batch_norm,
        label_smoothing_prob=0.1 if label_smoothing else 0,
        weight_noise_std=0,
        encoder_residual=encoder_residual,
        encoder_dense_residual=encoder_dense_residual)

    # Count total parameters
    for name in sorted(list(model.num_params_dict.keys())):
        num_params = model.num_params_dict[name]
        print("%s %d" % (name, num_params))
    print("Total %.3f M parameters" % (model.total_parameters / 1000000))

    # Define optimizer
    learning_rate = 1e-3
    model.set_optimizer('adam',
                        learning_rate_init=learning_rate,
                        weight_decay=1e-6,
                        lr_schedule=False,
                        factor=0.1,
                        patience_epoch=5)

    # Define learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               backend='pytorch',
                               decay_start_epoch=20,
                               decay_rate=0.9,
                               decay_patient_epoch=10,
                               lower_better=True)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # Train model
    max_step = 300
    start_time_step = time.time()
    for step in range(max_step):
        # Step for parameter update
        model.optimizer.zero_grad()
        loss, loss_main, loss_sub = model(xs, ys, x_lens, y_lens,
                                          ys_sub, y_lens_sub)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        model.optimizer.step()

        if (step + 1) % 10 == 0:
            # Compute loss
            loss, loss_main, loss_sub = model(xs, ys, x_lens, y_lens,
                                              ys_sub, y_lens_sub,
                                              is_eval=True)

            # Decode
            best_hyps, _, _ = model.decode(xs, x_lens,
                                           beam_width=2,
                                           task_index=0)
            best_hyps_sub, _, _ = model.decode(xs, x_lens,
                                               beam_width=2,
                                               task_index=1)

            str_ref = idx2word(ys[0, :y_lens[0]])
            str_hyp = idx2word(best_hyps[0])
            str_ref_sub = idx2char(ys_sub[0, :y_lens_sub[0]])
            str_hyp_sub = idx2char(best_hyps_sub[0])

            # Compute accuracy
            try:
                wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                           hyp=str_hyp.split('_'),
                                           normalize=True)
                cer, _, _, _ = compute_wer(
                    ref=list(str_ref_sub.replace('_', '')),
                    hyp=list(str_hyp_sub.replace('_', '')),
                    normalize=True)
            except:
                wer = 1
                cer = 1

            duration_step = time.time() - start_time_step
            print(
                'Step %d: loss=%.3f(%.3f/%.3f) / wer=%.3f / cer=%.3f / lr=%.5f (%.3f sec)'
                % (step + 1, loss, loss_main, loss_sub, wer, cer,
                   learning_rate, duration_step))
            start_time_step = time.time()

            # Visualize
            print('Ref: %s' % str_ref)
            print('Hyp (word): %s' % str_hyp)
            print('Hyp (char): %s' % str_hyp_sub)

            if cer < 0.1:
                print('Modle is Converged.')
                break

            # Update learning rate
            model.optimizer, learning_rate = lr_controller.decay_lr(
                optimizer=model.optimizer,
                learning_rate=learning_rate,
                epoch=step,
                value=wer)
def eval_word(models, dataset, beam_width, max_decode_len,
              beam_width_sub=1, max_decode_len_sub=300,
              eval_batch_size=None, length_penalty=0,
              progressbar=False, temperature=1,
              resolving_unk=False, a2c_oracle=False):
    """Evaluate trained model by Word Error Rate.
    Args:
        models (list): the models to evaluate
        dataset: An instance of a `Dataset' class
        beam_width (int): the size of beam for the main (word) task
        max_decode_len (int): the length of output sequences
            to stop prediction. This is used for seq2seq models.
        beam_width_sub (int, optional): the size of beam in the sub task.
            This is used for the nested attention model.
        max_decode_len_sub (int, optional): the length of output sequences
            to stop prediction in the sub task. This is used for the
            nested attention model.
        eval_batch_size (int, optional): the batch size when evaluating the model
        length_penalty (float, optional): length penalty applied in beam search
        progressbar (bool, optional): if True, visualize the progressbar
        temperature (int, optional): softmax temperature (currently unused here)
        resolving_unk (bool, optional): if True, replace OOV word hypotheses
            with character-level hypotheses via the attention alignments
        a2c_oracle (bool, optional): if True, teacher-force the sub (character)
            decoder with the ground-truth character sequence (nested attention)
    Returns:
        wer (float): Word error rate
        df_wer (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    idx2word = Idx2word(dataset.vocab_file_path)
    if models[0].model_type == 'nested_attention':
        char2idx = Char2idx(dataset.vocab_file_path_sub)
    # FIX: the original compared the model object itself (`models[0]`) against
    # the strings 'ctc'/'attention', which is always False, so idx2char was
    # never built and resolving_unk crashed later on an undefined name.
    if models[0].model_type in ['ctc', 'attention'] and resolving_unk:
        idx2char = Idx2char(
            dataset.vocab_file_path_sub,
            capital_divide=dataset.label_type_sub == 'character_capital_divide')

    wer = 0
    sub, ins, dele = 0, 0, 0
    num_words = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)
        batch_size = len(batch['xs'])

        # Decode
        if len(models) > 1:
            # Ensemble decoding: average frame-wise posteriors over all models
            assert models[0].model_type in ['ctc']
            for i, model in enumerate(models):
                probs, x_lens, perm_idx = model.posteriors(
                    batch['xs'], batch['x_lens'])
                if i == 0:
                    probs_ensemble = probs
                else:
                    probs_ensemble += probs
            probs_ensemble /= len(models)

            best_hyps = models[0].decode_from_probs(
                probs_ensemble, x_lens, beam_width=1)
        else:
            model = models[0]
            # TODO: fix this

            if model.model_type == 'nested_attention':
                # FIX: initialize these so the decode() call below never sees
                # undefined names (the original raised NameError whenever
                # a2c_oracle was False, or when a2c_oracle was True on a
                # non-test set where max_label_num was never assigned).
                ys_sub = None
                y_lens_sub = None
                max_label_num = None
                if a2c_oracle:
                    if dataset.is_test:
                        max_label_num = 0
                        for b in range(batch_size):
                            if max_label_num < len(list(batch['ys_sub'][b][0])):
                                max_label_num = len(
                                    list(batch['ys_sub'][b][0]))

                        ys_sub = np.zeros(
                            (batch_size, max_label_num), dtype=np.int32)
                        ys_sub -= 1  # pad with -1
                        y_lens_sub = np.zeros((batch_size,), dtype=np.int32)
                        for b in range(batch_size):
                            indices = char2idx(batch['ys_sub'][b][0])
                            ys_sub[b, :len(indices)] = indices
                            y_lens_sub[b] = len(indices)
                            # NOTE: transcript is seperated by space('_')
                    else:
                        ys_sub = batch['ys_sub']
                        y_lens_sub = batch['y_lens_sub']
                        # FIX: bound the oracle decode length by the padded
                        # label length on this path too
                        max_label_num = ys_sub.shape[1]

                best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    beam_width_sub=beam_width_sub,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_label_num if a2c_oracle else max_decode_len_sub,
                    length_penalty=length_penalty,
                    teacher_forcing=a2c_oracle,
                    ys_sub=ys_sub,
                    y_lens_sub=y_lens_sub)
            else:
                best_hyps, aw, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    length_penalty=length_penalty)
                if resolving_unk:
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'], batch['x_lens'],
                        beam_width=beam_width,
                        max_decode_len=max_decode_len_sub,
                        length_penalty=length_penalty,
                        task_index=1)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(batch_size):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2word(best_hyps[b])
            if dataset.label_type == 'word':
                str_hyp = re.sub(r'(.*)_>(.*)', r'\1', str_hyp)
            else:
                str_hyp = re.sub(r'(.*)>(.*)', r'\1', str_hyp)
            # NOTE: Trancate by the first <EOS>

            ##############################
            # Resolving UNK
            ##############################
            if resolving_unk and 'OOV' in str_hyp:
                str_hyp = resolve_unk(
                    str_hyp, best_hyps_sub[b], aw[b], aw_sub[b], idx2char)
                str_hyp = str_hyp.replace('*', '')

            ##############################
            # Post-proccessing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)
            # NOTE: @ means noise

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            # Compute WER
            try:
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_words += len(str_ref.split('_'))
            except Exception:
                # FIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit. Still best-effort: skip
                # utterances whose alignment fails.
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # FIX: guard against an empty evaluation set (ZeroDivisionError before)
    if num_words > 0:
        wer /= num_words
        sub /= num_words
        ins /= num_words
        dele /= num_words

    df_wer = pd.DataFrame(
        {'SUB': [sub * 100], 'INS': [ins * 100], 'DEL': [dele * 100]},
        columns=['SUB', 'INS', 'DEL'],
        index=['WER'])

    return wer, df_wer
def do_eval_cer(models, model_type, dataset, label_type, beam_width, max_decode_len, eval_batch_size=None, temperature=1, progressbar=False): """Evaluate trained models by Character Error Rate. Args: models (list): the model to evaluate model_type (string): ctc or attention or hierarchical_ctc or hierarchical_attention dataset: An instance of a `Dataset' class label_type (string): character or character_capital_divide beam_width: (int): the size of beam max_decode_len (int): the length of output sequences to stop prediction when EOS token have not been emitted. This is used for seq2seq models. eval_batch_size (int, optional): the batch size when evaluating the model temperature (int, optional): progressbar (bool, optional): if True, visualize the progressbar Returns: wer (float): Word error rate cer (float): Character error rate df_wer_cer (pd.DataFrame): dataframe of substitution, insertion, and deletion """ # Reset data counter dataset.reset() idx2char = Idx2char( vocab_file_path=dataset.vocab_file_path, capital_divide=(dataset.label_type == 'character_capital_divide')) cer, wer = 0, 0 sub_char, ins_char, del_char = 0, 0, 0 sub_word, ins_word, del_word = 0, 0, 0 num_words, num_chars = 0, 0 if progressbar: pbar = tqdm(total=len(dataset)) # TODO: fix this while True: batch, is_new_epoch = dataset.next(batch_size=eval_batch_size) # Decode the ensemble if model_type in ['attention', 'ctc']: for i, model in enumerate(models): probs_i, perm_idx = model.posteriors(batch['xs'], batch['x_lens'], temperature=temperature) if i == 0: probs = probs_i else: probs += probs_i # NOTE: probs: `[1 (B), T, num_classes]` probs /= len(models) best_hyps = model.decode_from_probs(probs, batch['x_lens'][perm_idx], beam_width=beam_width, max_decode_len=max_decode_len) ys = batch['ys'][perm_idx] y_lens = batch['y_lens'][perm_idx] elif model_type in ['hierarchical_attention', 'hierarchical_ctc']: raise NotImplementedError for b in range(len(batch['xs'])): ############################## # 
Reference ############################## if dataset.is_test: str_ref = ys[b][0] # NOTE: transcript is seperated by space('_') else: # Convert from list of index to string str_ref = idx2char(ys[b][:y_lens[b]]) ############################## # Hypothesis ############################## str_hyp = idx2char(best_hyps[b]) if 'attention' in model.model_type: str_hyp = str_hyp.split('>')[0] # NOTE: Trancate by the first <EOS> # Remove the last space if len(str_hyp) > 0 and str_hyp[-1] == '_': str_hyp = str_hyp[:-1] # Remove consecutive spaces str_hyp = re.sub(r'[_]+', '_', str_hyp) ############################## # Post-proccessing ############################## # Remove garbage labels str_ref = re.sub(r'[\'>]+', '', str_ref) str_hyp = re.sub(r'[\'>]+', '', str_hyp) # TODO: WER計算するときに消していい? # Compute WER wer_b, sub_b, ins_b, del_b = compute_wer(ref=str_ref.split('_'), hyp=str_hyp.split('_'), normalize=False) wer += wer_b sub_word += sub_b ins_word += ins_b del_word += del_b num_words += len(str_ref.split('_')) # Compute CER cer_b, sub_b, ins_b, del_b = compute_wer( ref=list(str_ref.replace('_', '')), hyp=list(str_hyp.replace('_', '')), normalize=False) cer += cer_b sub_char += sub_b ins_char += ins_b del_char += del_b num_chars += len(str_ref.replace('_', '')) if progressbar: pbar.update(1) if is_new_epoch: break if progressbar: pbar.close() wer /= num_words cer /= num_chars sub_char /= num_chars ins_char /= num_chars del_char /= num_chars sub_word /= num_words ins_word /= num_words del_word /= num_words df_wer_cer = pd.DataFrame( { 'SUB': [sub_char * 100, sub_word * 100], 'INS': [ins_char * 100, ins_word * 100], 'DEL': [del_char * 100, del_word * 100] }, columns=['SUB', 'INS', 'DEL'], index=['CER', 'WER']) return cer, wer, df_wer_cer
def do_eval_cer(session, decode_op, model, dataset, label_type,
                eval_batch_size=None, progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Character Error Rate (TensorFlow session API).

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
            (its batches carry an extra label field that is ignored here)
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    """
    # Reset data counter
    dataset.reset()

    # Build the index-to-character mapping from the era-specific map files
    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/ctc/character.txt')
    elif label_type == 'character_capital_divide':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/ctc/character_capital_divide.txt',
            capital_divide=True,
            space_mark='_')

    cer_mean, wer_mean = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
        # NOTE(review): pbar is never closed in this function
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini batch
        if is_multitask:
            inputs, labels_true, _, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        # Dropout is disabled at evaluation time (keep probabilities = 1.0)
        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        batch_size_each = len(inputs)

        # Run CTC decoding and convert the sparse result to per-utterance lists
        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        labels_pred = sparsetensor2list(labels_pred_st, batch_size_each)
        for i_batch in range(batch_size_each):
            # Convert from list of index to string
            str_true = idx2char(labels_true[i_batch])
            str_pred = idx2char(labels_pred[i_batch])

            # Remove consecutive spaces
            str_pred = re.sub(r'[_]+', '_', str_pred)

            # Remove garbage labels (punctuation etc.)
            str_true = re.sub(r'[\'\":;!?,.-]+', "", str_true)
            str_pred = re.sub(r'[\'\":;!?,.-]+', "", str_pred)

            # Compute WER
            # NOTE(review): elsewhere in this file compute_wer returns a
            # 4-tuple (wer, sub, ins, del); here a scalar return is assumed.
            # Verify which compute_wer this legacy function imports.
            wer_mean += compute_wer(hyp=str_pred.split('_'),
                                    ref=str_true.split('_'),
                                    normalize=True)
            # substitute, insert, delete = wer_align(
            #     ref=str_pred.split('_'),
            #     hyp=str_true.split('_'))
            # print(substitute)
            # print(insert)
            # print(delete)

            # Remove spaces before character-level scoring
            str_pred = re.sub(r'[_]+', "", str_pred)
            str_true = re.sub(r'[_]+', "", str_true)

            # Compute CER
            cer_mean += compute_cer(str_pred=str_pred,
                                    str_true=str_true,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    # NOTE(review): averages divide by the number of *batches* (len(dataset)),
    # while the sums above accumulate per *utterance* — confirm len(dataset)
    # counts utterances, not batches.
    cer_mean /= len(dataset)
    wer_mean /= len(dataset)

    return cer_mean, wer_mean
def check(self, encoder_type, bidirectional=False, label_type='char',
          subsample=False, projection=False, conv=False, batch_norm=False,
          activation='relu', encoder_residual=False,
          encoder_dense_residual=False, label_smoothing=False):
    """Smoke-test one CTC model configuration (chainer backend).

    Builds a small CTC model from the given encoder options, overfits it on a
    tiny generated batch, and decodes every 10 steps until the label error
    rate drops below 0.05 or max_step is reached.

    Args:
        encoder_type (string): type of the encoder (e.g. lstm, cnn)
        bidirectional (bool, optional): use a bidirectional encoder
        label_type (string, optional): 'char' or 'word'
        subsample (bool, optional): enable encoder subsampling
        projection (bool, optional): add projection layers to the encoder
        conv (bool, optional): prepend a convolutional front-end
        batch_norm (bool, optional): enable batch normalization
        activation (string, optional): activation function name
        encoder_residual (bool, optional): residual encoder connections
        encoder_dense_residual (bool, optional): dense residual connections
        label_smoothing (bool, optional): enable label smoothing
    """
    print('==================================================')
    print('  label_type: %s' % label_type)
    print('  encoder_type: %s' % encoder_type)
    print('  bidirectional: %s' % str(bidirectional))
    print('  projection: %s' % str(projection))
    print('  subsample: %s' % str(subsample))
    print('  conv: %s' % str(conv))
    print('  batch_norm: %s' % str(batch_norm))
    print('  activation: %s' % activation)
    print('  encoder_residual: %s' % str(encoder_residual))
    print('  encoder_dense_residual: %s' % str(encoder_dense_residual))
    print('  label_smoothing: %s' % str(label_smoothing))
    print('==================================================')

    if conv or encoder_type == 'cnn':
        # pattern 1
        # conv_channels = [32, 32]
        # conv_kernel_sizes = [[41, 11], [21, 11]]
        # conv_strides = [[2, 2], [2, 1]]
        # poolings = [[], []]

        # pattern 2 (VGG like)
        conv_channels = [64, 64]
        conv_kernel_sizes = [[3, 3], [3, 3]]
        conv_strides = [[1, 1], [1, 1]]
        poolings = [[2, 2], [2, 2]]
        fc_list = [786, 786]
    else:
        conv_channels = []
        conv_kernel_sizes = []
        conv_strides = []
        poolings = []
        fc_list = []

    # Load batch data
    splice = 1
    # Frame stacking only when neither subsampling nor a conv front-end
    # already reduces the time resolution
    num_stack = 1 if subsample or conv or encoder_type == 'cnn' else 2
    xs, ys, x_lens, y_lens = generate_data(label_type=label_type,
                                           batch_size=2,
                                           num_stack=num_stack,
                                           splice=splice,
                                           backend='chainer')

    if label_type == 'char':
        num_classes = 27
        map_fn = idx2char
    elif label_type == 'word':
        num_classes = 11
        map_fn = idx2word

    # Load model
    model = CTC(
        input_size=xs[0].shape[-1] // splice // num_stack,  # 120
        encoder_type=encoder_type,
        encoder_bidirectional=bidirectional,
        encoder_num_units=256,
        encoder_num_proj=256 if projection else 0,
        encoder_num_layers=1 if not subsample else 2,
        fc_list=fc_list,
        dropout_input=0.1,
        dropout_encoder=0.1,
        num_classes=num_classes,
        parameter_init_distribution='uniform',
        parameter_init=0.1,
        recurrent_weight_orthogonal=False,
        init_forget_gate_bias_with_one=True,
        subsample_list=[] if not subsample else [True] * 2,
        num_stack=num_stack,
        splice=splice,
        input_channel=3,
        conv_channels=conv_channels,
        conv_kernel_sizes=conv_kernel_sizes,
        conv_strides=conv_strides,
        poolings=poolings,
        activation=activation,
        batch_norm=batch_norm,
        label_smoothing_prob=0.1 if label_smoothing else 0,
        weight_noise_std=0,
        encoder_residual=encoder_residual,
        encoder_dense_residual=encoder_dense_residual)

    # Count total parameters
    for name in sorted(list(model.num_params_dict.keys())):
        num_params = model.num_params_dict[name]
        print("%s %d" % (name, num_params))
    print("Total %.3f M parameters" % (model.total_parameters / 1000000))

    # Define optimizer
    learning_rate = 1e-3
    model.set_optimizer(
        'adam',
        # 'adadelta',
        learning_rate_init=learning_rate,
        weight_decay=1e-6,
        clip_grad_norm=5,
        lr_schedule=None,
        factor=None,
        patience_epoch=None)

    # Define learning rate controller
    lr_controller = Controller(learning_rate_init=learning_rate,
                               backend='chainer',
                               decay_start_epoch=20,
                               decay_rate=0.9,
                               decay_patient_epoch=10,
                               lower_better=True)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # Train model
    max_step = 300
    start_time_step = time.time()
    for step in range(max_step):
        # Step for parameter update
        # (chainer idiom: clear grads, backprop, free the graph, update)
        loss = model(xs, ys, x_lens, y_lens)
        model.optimizer.target.cleargrads()
        model.cleargrads()
        loss.backward()
        loss.unchain_backward()
        model.optimizer.update()

        # Inject Gaussian noise to all parameters

        if (step + 1) % 10 == 0:
            # Compute loss
            loss = model(xs, ys, x_lens, y_lens, is_eval=True)

            # Decode
            best_hyps, _, _ = model.decode(xs, x_lens, beam_width=1)
            # TODO: fix beam search

            str_ref = map_fn(ys[0, :y_lens[0]])
            str_hyp = map_fn(best_hyps[0])

            # Compute accuracy
            # NOTE(review): bare `except` below maps ANY failure (including
            # KeyboardInterrupt) to ler = 1 — consider `except Exception`.
            try:
                if label_type == 'char':
                    ler, _, _, _ = compute_wer(
                        ref=list(str_ref.replace('_', '')),
                        hyp=list(str_hyp.replace('_', '')),
                        normalize=True)
                elif label_type == 'word':
                    ler, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=str_hyp.split('_'),
                                               normalize=True)
            except:
                ler = 1

            duration_step = time.time() - start_time_step
            print('Step %d: loss=%.3f / ler=%.3f / lr=%.5f (%.3f sec)' %
                  (step + 1, loss, ler, learning_rate, duration_step))
            start_time_step = time.time()

            # Visualize
            print('Ref: %s' % str_ref)
            print('Hyp: %s' % str_hyp)

            if ler < 0.05:
                # NOTE(review): 'Modle' is a typo in the user-facing message
                print('Modle is Converged.')
                break

            # Update learning rate
            model.optimizer, learning_rate = lr_controller.decay_lr(
                optimizer=model.optimizer,
                learning_rate=learning_rate,
                epoch=step,
                value=ler)
def decode(model, dataset, eval_batch_size, beam_width,
           length_penalty, save_path=None):
    """Visualize label outputs.

    Decodes every batch of `dataset`, printing reference/hypothesis pairs and
    per-utterance WER or CER depending on the dataset's label type.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam
        length_penalty (float): length penalty in beam search
            (currently not forwarded to model.decode here)
        save_path (string): path to save decoding results
    """
    # Word vs. character output determines the mapping and max decode length
    if 'word' in dataset.label_type:
        map_fn = Idx2word(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_WORD
    else:
        map_fn = Idx2char(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_CHAR

    if save_path is not None:
        # NOTE(review): redirects the process-wide stdout to a file and never
        # restores or closes it; also the file goes to model.model_dir while
        # save_path itself is only used as an on/off flag — confirm intent.
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:
        # Decode
        if model.model_type == 'nested_attention':
            best_hyps, _, best_hyps_sub, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                max_decode_len_sub=max_decode_len)
        else:
            best_hyps, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len)

        if model.model_type == 'attention' and model.ctc_loss_weight > 0:
            # Auxiliary CTC head hypotheses for joint CTC-attention models
            best_hyps_ctc, perm_idx = model.decode_ctc(batch['xs'],
                                                       batch['x_lens'],
                                                       beam_width=beam_width)

        # Reorder targets to match the (sorted) decoding order
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = map_fn(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = map_fn(best_hyps[b])

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref: %s' % str_ref.replace('_', ' '))
            print('Hyp: %s' % str_hyp.replace('_', ' '))
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                str_hyp_ctc = map_fn(best_hyps_ctc[b])
                print('Hyp (CTC): %s' % str_hyp_ctc)

            # NOTE(review): bare `except` below also swallows
            # KeyboardInterrupt/SystemExit — consider `except Exception`.
            try:
                if dataset.label_type == 'word' or dataset.label_type == 'kanji_wb':
                    # Word-level score; strip everything after <EOS> first
                    wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                               hyp=re.sub(r'(.*)[_]*>(.*)', r'\1',
                                                          str_hyp).split('_'),
                                               normalize=True)
                    print('WER: %.3f %%' % (wer * 100))
                    if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                        wer_ctc, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=str_hyp_ctc.split('_'),
                            normalize=True)
                        print('WER (CTC): %.3f %%' % (wer_ctc * 100))
                else:
                    # Character-level score; spaces removed, <EOS> stripped
                    cer, _, _, _ = compute_wer(
                        ref=list(str_ref.replace('_', '')),
                        hyp=list(re.sub(r'(.*)>(.*)', r'\1',
                                        str_hyp).replace('_', '')),
                        normalize=True)
                    print('CER: %.3f %%' % (cer * 100))
                    if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                        cer_ctc, _, _, _ = compute_wer(
                            ref=list(str_ref.replace('_', '')),
                            hyp=list(str_hyp_ctc.replace('_', '')),
                            normalize=True)
                        print('CER (CTC): %.3f %%' % (cer_ctc * 100))
            except:
                print('--- skipped ---')

        if is_new_epoch:
            break
def decode(model, dataset, beam_width, max_decode_len,
           eval_batch_size=None, save_path=None):
    """Visualize label outputs.

    Decodes every batch of `dataset`, printing reference/hypothesis pairs and
    per-utterance WER or CER depending on the dataset's label type.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the model
        save_path (string): path to save decoding results
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Build the vocab path from the dataset's label type and data size
    vocab_file_path = '../metrics/vocab_files/' + \
        dataset.label_type + '_' + dataset.data_size + '.txt'
    if dataset.label_type == 'character':
        map_fn = Idx2char(vocab_file_path)
    elif dataset.label_type == 'character_capital_divide':
        map_fn = Idx2char(vocab_file_path, capital_divide=True)
    else:
        map_fn = Idx2word(vocab_file_path)

    if save_path is not None:
        # NOTE(review): redirects the process-wide stdout to a file and never
        # restores or closes it; save_path is only used as an on/off flag.
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:
        # Decode
        best_hyps, perm_idx = model.decode(batch['xs'], batch['x_lens'],
                                           beam_width=beam_width,
                                           max_decode_len=max_decode_len)
        if model.model_type == 'attention' and model.ctc_loss_weight > 0:
            # Auxiliary CTC head hypotheses for joint CTC-attention models
            best_hyps_ctc, perm_idx = model.decode_ctc(
                batch['xs'], batch['x_lens'], beam_width=beam_width)

        # Reorder targets to match the (sorted) decoding order
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is seperated by space('_')
            else:
                # Convert from list of index to string
                str_ref = map_fn(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = map_fn(best_hyps[b])
            if model.model_type == 'attention':
                str_hyp = str_hyp.split('>')[0]
                # NOTE: Trancate by the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]

            ##############################
            # Post-proccessing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[\'>]+', '', str_ref)
            str_hyp = re.sub(r'[\'>]+', '', str_hyp)

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref: %s' % str_ref.replace('_', ' '))
            print('Hyp: %s' % str_hyp.replace('_', ' '))
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                str_hyp_ctc = map_fn(best_hyps_ctc[b])
                print('Hyp (CTC): %s' % str_hyp_ctc)

            # Compute CER
            if 'word' in dataset.label_type:
                wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                           hyp=str_hyp.split('_'),
                                           normalize=True)
                print('WER: %f %%' % (wer * 100))
                # NOTE(review): this guard lacks the model_type check used
                # everywhere else in this function — confirm whether
                # decode_ctc is guaranteed to have run here.
                if model.ctc_loss_weight > 0:
                    wer_ctc, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                                   hyp=str_hyp_ctc.split('_'),
                                                   normalize=True)
                    print('WER (CTC): %f %%' % (wer_ctc * 100))
            else:
                # NOTE(review): labeled CER but computed over '_'-separated
                # tokens (word-level split), unlike the character-level CER
                # used elsewhere in this file — verify this is intentional.
                cer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                           hyp=str_hyp.split('_'),
                                           normalize=True)
                print('CER: %f %%' % (cer * 100))
                if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                    cer_ctc, _, _, _ = compute_wer(
                        ref=list(str_ref.replace('_', '')),
                        hyp=list(str_hyp_ctc.replace('_', '')),
                        normalize=True)
                    print('CER (CTC): %f %%' % (cer_ctc * 100))

        if is_new_epoch:
            break