def plot(model, dataset, eval_batch_size=None, save_path=None,
         space_index=None):
    """Save a figure of the CTC posteriors for every utterance in `dataset`.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int, optional): the batch size when evaluating the
            model
        save_path (string): path to save figures of CTC posteriors
        space_index (int, optional): forwarded unchanged to `plot_ctc_probs`
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Clean directory
    if isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    vocab_file_path = '../metrics/vocab_files/' + \
        dataset.label_type + '_' + dataset.data_size + '.txt'
    if dataset.label_type == 'character':
        map_fn = Idx2char(vocab_file_path)
    elif dataset.label_type == 'character_capital_divide':
        map_fn = Idx2char(vocab_file_path, capital_divide=True)
    else:
        map_fn = Idx2word(vocab_file_path)

    for batch, is_new_epoch in dataset:

        # Get CTC probs
        probs = model.posteriors(batch['xs'], batch['x_lens'], temperature=1)
        # NOTE: probs: '[B, T, num_classes]'

        # Decode
        # The result must be kept: it is indexed per utterance below
        # (it was previously discarded into `_`, causing a NameError).
        best_hyps = model.decode(batch['xs'], batch['x_lens'], beam_width=1)

        # Visualize
        for b in range(len(batch['xs'])):

            # Convert from list of index to string
            str_pred = map_fn(best_hyps[b])

            # Figures are grouped by the first two '-'-separated name fields
            speaker, book = batch['input_names'][b].split('-')[:2]
            plot_ctc_probs(
                probs[b, :batch['x_lens'][b], :],
                frame_num=batch['x_lens'][b],
                num_stack=dataset.num_stack,
                space_index=space_index,
                str_pred=str_pred,
                save_path=mkdir_join(save_path, speaker, book,
                                     batch['input_names'][b] + '.png'))

        if is_new_epoch:
            break
def plot(model, dataset, eval_batch_size, save_path=None):
    """Save figures of main-task and sub-task CTC posteriors per utterance.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
            (`None` leaves the dataset's batch size untouched)
        save_path (string): path to save figures of CTC posteriors
    """
    # Override the batch size only when one was given
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Start from a fresh output directory
    if isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    # Main task is mapped to words, sub task to characters
    idx2word = Idx2word(dataset.vocab_file_path)
    use_capital_divide = (
        dataset.label_type_sub == 'character_capital_divide')
    idx2char = Idx2char(dataset.vocab_file_path,
                        capital_divide=use_capital_divide)

    for batch, is_new_epoch in dataset:
        xs = batch['xs']
        x_lens = batch['x_lens']

        # Posteriors of both tasks
        # NOTE: probs: '[B, T, num_classes]'
        # NOTE: probs_sub: '[B, T, num_classes_sub]'
        probs = model.posteriors(xs, x_lens, temperature=1)
        probs_sub = model.posteriors(xs, x_lens, is_sub_task=True,
                                     temperature=1)

        # Greedy hypotheses of both tasks
        best_hyps = model.decode(xs, x_lens, beam_width=1)
        best_hyps_sub = model.decode(xs, x_lens, beam_width=1,
                                     is_sub_task=True)

        # One figure per utterance
        for b in range(len(xs)):
            # Index sequences -> strings
            str_hyp = idx2word(best_hyps[b])
            str_hyp_sub = idx2char(best_hyps_sub[b])

            utt_name = batch['input_names'][b]
            speaker = utt_name.split('_')[0]
            n_frames = x_lens[b]

            # Drop the padded frames before plotting
            probs_utt = probs[b, :n_frames, :]
            probs_sub_utt = probs_sub[b, :n_frames, :]

            plot_hierarchical_ctc_probs(
                probs_utt,
                probs_sub_utt,
                frame_num=n_frames,
                num_stack=dataset.num_stack,
                str_hyp=str_hyp,
                str_hyp_sub=str_hyp_sub,
                save_path=mkdir_join(save_path, speaker, utt_name + '.png'))

        if is_new_epoch:
            break
def plot(model, dataset, beam_width, eval_batch_size=None, save_path=None):
    """Visualize attention weights of attention-based model.

    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width: (int): the size of beam
        eval_batch_size (int, optional): the batch size when evaluating the
            model
        save_path (string, optional): path to save attention weights plotting
    """
    # Set batch size in the evaluation
    # (the argument was previously accepted but silently ignored)
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    # Character-level and word-level labels use different mappers and
    # different limits on the decoded length
    if 'char' in dataset.label_type:
        map_fn = Idx2char(
            dataset.vocab_file_path,
            capital_divide=dataset.label_type == 'character_capital_divide',
            return_list=True)
        max_decode_len = MAX_DECODE_LEN_CHAR
    else:
        map_fn = Idx2word(dataset.vocab_file_path, return_list=True)
        max_decode_len = MAX_DECODE_LEN_WORD

    for batch, is_new_epoch in dataset:

        # Decode
        best_hyps, aw, perm_idx = model.attention_weights(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=max_decode_len)

        # References are re-ordered with the permutation returned by the model
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = map_fn(ys[b][:y_lens[b]])

            token_list = map_fn(best_hyps[b])

            # Figures are grouped by the first two '_'-separated name fields
            speaker = '_'.join(batch['input_names'][b].split('_')[:2])
            plot_attention_weights(
                aw[b, :len(token_list), :batch['x_lens'][b]],
                label_list=token_list,
                spectrogram=batch['xs'][b, :, :dataset.input_freq],
                str_ref=str_ref,
                save_path=mkdir_join(save_path, speaker,
                                     batch['input_names'][b] + '.png'),
                figsize=(20, 8))

        if is_new_epoch:
            break
def check_loading(self, label_type, data_type='dev', shuffle=False,
                  sort_utt=False, sort_stop_epoch=None, frame_stacking=False,
                  splice=1):
    """Load mini-batches and print the first reference of each batch.

    Prints the settings, builds a `Dataset` with both attention and CTC
    labels, then prints the attention and CTC transcripts of the first
    utterance of every mini-batch.

    Args:
        label_type (string): 'character' and 'character_capital_divide' are
            mapped with `Idx2char`; anything else with `Idx2phone`
        data_type (string, optional): forwarded to `Dataset`
        shuffle (bool, optional): forwarded to `Dataset`
        sort_utt (bool, optional): forwarded to `Dataset`
        sort_stop_epoch (int, optional): forwarded to `Dataset`
        frame_stacking (bool, optional): if True, stack and skip 3 frames
        splice (int, optional): forwarded to `Dataset`
    """
    print('========================================')
    print(' label_type: %s' % label_type)
    print(' data_type: %s' % data_type)
    # Report the shuffle flag itself (sort_utt was printed here by mistake)
    print(' shuffle: %s' % str(shuffle))
    print(' sort_utt: %s' % str(sort_utt))
    print(' sort_stop_epoch: %s' % str(sort_stop_epoch))
    print(' frame_stacking: %s' % str(frame_stacking))
    print(' splice: %d' % splice)
    print('========================================')

    num_stack = 3 if frame_stacking else 1
    num_skip = 3 if frame_stacking else 1
    dataset = Dataset(
        data_type=data_type, label_type=label_type,
        batch_size=64, eos_index=1, max_epoch=2,
        splice=splice, num_stack=num_stack, num_skip=num_skip,
        shuffle=shuffle, sort_utt=sort_utt,
        sort_stop_epoch=sort_stop_epoch, progressbar=True)

    print('=> Loading mini-batch...')
    # CTC and attention labels have separate mapping files
    if label_type in ['character', 'character_capital_divide']:
        map_fn_ctc = Idx2char(
            map_file_path='../../metrics/mapping_files/ctc/' +
            label_type + '.txt')
        map_fn_att = Idx2char(
            map_file_path='../../metrics/mapping_files/attention/' +
            label_type + '.txt')
    else:
        map_fn_ctc = Idx2phone(
            map_file_path='../../metrics/mapping_files/ctc/' +
            label_type + '.txt')
        map_fn_att = Idx2phone(
            map_file_path='../../metrics/mapping_files/attention/' +
            label_type + '.txt')

    for data, is_new_epoch in dataset:
        (inputs, att_labels, ctc_labels, inputs_seq_len,
         att_labels_seq_len, input_names) = data

        # Only the first utterance of each mini-batch is shown
        att_str_true = map_fn_att(att_labels[0][0: att_labels_seq_len[0]])
        ctc_str_true = map_fn_ctc(ctc_labels[0])
        att_str_true = re.sub(r'_', ' ', att_str_true)
        ctc_str_true = re.sub(r'_', ' ', ctc_str_true)
        print('----- %s ----- (epoch: %.3f)' %
              (input_names[0], dataset.epoch_detail))
        print(att_str_true)
        print(ctc_str_true)
def check(self, ss_type, data_type='dev', shuffle=False, sort_utt=False,
          sort_stop_epoch=None, frame_stacking=False, splice=1):
    """Load kana mini-batches and print the first utterance of each one.

    For the training set, every utterance of the first device is also
    checked: its number of input frames must not be smaller than its label
    length, otherwise `ValueError` is raised. Iteration stops once
    `dataset.epoch_detail` reaches 0.2.

    Args:
        ss_type (string): selects the mapping file `kana_<ss_type>.txt` and
            is forwarded to `Dataset`
        data_type (string, optional): forwarded to `Dataset`
        shuffle (bool, optional): forwarded to `Dataset`
        sort_utt (bool, optional): forwarded to `Dataset`
        sort_stop_epoch (int, optional): forwarded to `Dataset`
        frame_stacking (bool, optional): if True, stack and skip 3 frames
        splice (int, optional): forwarded to `Dataset`
    """
    separator = '========================================'
    settings = [
        (' label_type: %s', 'kana'),
        (' ss_type: %s', ss_type),
        (' data_type: %s', data_type),
        (' shuffle: %s', str(shuffle)),
        (' sort_utt: %s', str(sort_utt)),
        (' sort_stop_epoch: %s', str(sort_stop_epoch)),
        (' frame_stacking: %s', str(frame_stacking)),
        (' splice: %d', splice),
    ]
    print(separator)
    for template, value in settings:
        print(template % value)
    print(separator)

    map_file_path = '../../metrics/mapping_files/kana_' + ss_type + '.txt'

    # Stacking and skipping always use the same factor
    num_stack = num_skip = 3 if frame_stacking else 1
    dataset = Dataset(
        data_type=data_type, label_type='kana', ss_type=ss_type,
        batch_size=64, map_file_path=map_file_path,
        max_epoch=2, splice=splice,
        num_stack=num_stack, num_skip=num_skip,
        shuffle=shuffle, sort_utt=sort_utt,
        sort_stop_epoch=sort_stop_epoch, progressbar=True)

    print('=> Loading mini-batch...')
    map_fn = Idx2char(map_file_path=map_file_path)

    for data, is_new_epoch in dataset:
        inputs, labels, inputs_seq_len, labels_seq_len, input_names = data

        if data_type == 'train':
            for frames, label in zip(inputs[0], labels[0]):
                # The label ends at the first padded position, if any
                pad_positions = np.where(label == dataset.padded_value)[0]
                if len(pad_positions) > 0:
                    label_len = pad_positions[0]
                else:
                    label_len = len(label)
                if frames.shape[0] < label_len:
                    raise ValueError(
                        'input length must be longer than label length.')

        # Test labels are stored as raw transcripts, others as indices
        if data_type == 'test':
            str_true = labels[0][0][0]
        else:
            str_true = map_fn(labels[0][0][:labels_seq_len[0][0]])

        print('----- %s ----- (epoch: %.3f)' %
              (input_names[0][0], dataset.epoch_detail))
        print(inputs[0][0].shape)
        print(str_true)

        if dataset.epoch_detail >= 0.2:
            break
def check(self, label_type_main, data_type='dev', shuffle=False,
          sort_utt=False, sort_stop_epoch=None, frame_stacking=False,
          splice=1):
    """Load multitask mini-batches and print the first utterance of each.

    The main-task labels are mapped with `Idx2char` and the sub-task labels
    (always 'phone61' here) with `Idx2phone`.

    Args:
        label_type_main (string): main-task label type; also selects the
            main mapping file
        data_type (string, optional): forwarded to `Dataset`
        shuffle (bool, optional): forwarded to `Dataset`
        sort_utt (bool, optional): forwarded to `Dataset`
        sort_stop_epoch (int, optional): forwarded to `Dataset`
        frame_stacking (bool, optional): if True, stack and skip 3 frames
        splice (int, optional): forwarded to `Dataset`
    """
    separator = '========================================'
    settings = [
        (' label_type_main: %s', label_type_main),
        (' data_type: %s', data_type),
        (' shuffle: %s', str(shuffle)),
        (' sort_utt: %s', str(sort_utt)),
        (' sort_stop_epoch: %s', str(sort_stop_epoch)),
        (' frame_stacking: %s', str(frame_stacking)),
        (' splice: %d', splice),
    ]
    print(separator)
    for template, value in settings:
        print(template % value)
    print(separator)

    # Stacking and skipping always use the same factor
    num_stack = num_skip = 3 if frame_stacking else 1
    dataset = Dataset(
        data_type=data_type,
        label_type_main=label_type_main, label_type_sub='phone61',
        batch_size=64, max_epoch=2, splice=splice,
        num_stack=num_stack, num_skip=num_skip,
        shuffle=shuffle, sort_utt=sort_utt,
        sort_stop_epoch=sort_stop_epoch, progressbar=True)

    print('=> Loading mini-batch...')
    main_map_path = '../../metrics/mapping_files/' + label_type_main + '.txt'
    idx2char = Idx2char(map_file_path=main_map_path)
    idx2phone = Idx2phone(
        map_file_path='../../metrics/mapping_files/phone61.txt')

    for data, is_new_epoch in dataset:
        inputs, labels_char, labels_phone, inputs_seq_len, input_names = data

        # Test labels are stored as raw transcripts, others as indices
        if data_type == 'test':
            str_true_char = labels_char[0][0][0]
            str_true_phone = labels_phone[0][0][0]
        else:
            str_true_char = idx2char(labels_char[0][0])
            str_true_phone = idx2phone(labels_phone[0][0])

        print('----- %s ----- (epoch: %.3f)' %
              (input_names[0][0], dataset.epoch_detail))
        print(str_true_char)
        print(str_true_phone)
def plot(model, dataset, eval_batch_size, beam_width, beam_width_sub,
         length_penalty, save_path=None):
    """Visualize attention weights of attention-based model.

    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
            (`None` leaves the dataset's batch size untouched)
        beam_width: (int): the size of beam in the main task
        beam_width_sub: (int): the size of beam in the sub task
        length_penalty (float): currently not used by this function
        save_path (string, optional): path to save attention weights plotting
    """
    # Set batch size in the evaluation
    # (the argument was previously accepted but silently ignored)
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    # Main task is mapped to words, sub task to characters
    map_fn_main = Idx2word(dataset.vocab_file_path, return_list=True)
    map_fn_sub = Idx2char(dataset.vocab_file_path_sub, return_list=True)

    for batch, is_new_epoch in dataset:

        # Decode
        best_hyps, aw, _ = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=MAX_DECODE_LEN_WORD)
        best_hyps_sub, aw_sub, _ = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width_sub,
            max_decode_len=MAX_DECODE_LEN_CHAR,
            task_index=1)

        for b in range(len(batch['xs'])):

            word_list = map_fn_main(best_hyps[b])
            char_list = map_fn_sub(best_hyps_sub[b])

            # Figures are grouped by the first '_'-separated name field
            speaker = batch['input_names'][b].split('_')[0]

            # Weights are cropped to the decoded tokens and the real frames
            plot_hierarchical_attention_weights(
                aw[b][:len(word_list), :batch['x_lens'][b]],
                aw_sub[b][:len(char_list), :batch['x_lens'][b]],
                label_list=word_list,
                label_list_sub=char_list,
                spectrogram=batch['xs'][b, :, :dataset.input_freq],
                save_path=mkdir_join(save_path, speaker,
                                     batch['input_names'][b] + '.png'),
                figsize=(40, 8)
            )

        if is_new_epoch:
            break
def check(self, label_type, data_type='dev', shuffle=False, sort_utt=False,
          sort_stop_epoch=None, frame_stacking=False, splice=1):
    """Load one epoch of mini-batches and print the first utterance of each.

    Args:
        label_type (string): 'character' and 'character_capital_divide' are
            mapped with `Idx2char`; anything else with `Idx2phone`. Also
            selects the mapping file.
        data_type (string, optional): forwarded to `Dataset`
        shuffle (bool, optional): forwarded to `Dataset`
        sort_utt (bool, optional): forwarded to `Dataset`
        sort_stop_epoch (int, optional): forwarded to `Dataset`
        frame_stacking (bool, optional): if True, stack and skip 3 frames
        splice (int, optional): forwarded to `Dataset`
    """
    separator = '========================================'
    settings = [
        (' label_type: %s', label_type),
        (' data_type: %s', data_type),
        (' shuffle: %s', str(shuffle)),
        (' sort_utt: %s', str(sort_utt)),
        (' sort_stop_epoch: %s', str(sort_stop_epoch)),
        (' frame_stacking: %s', str(frame_stacking)),
        (' splice: %d', splice),
    ]
    print(separator)
    for template, value in settings:
        print(template % value)
    print(separator)

    map_file_path = '../../metrics/mapping_files/' + label_type + '.txt'

    # Stacking and skipping always use the same factor
    num_stack = num_skip = 3 if frame_stacking else 1
    dataset = Dataset(
        data_type=data_type, label_type=label_type,
        batch_size=64, map_file_path=map_file_path,
        max_epoch=1, splice=splice,
        num_stack=num_stack, num_skip=num_skip,
        shuffle=shuffle, sort_utt=sort_utt,
        sort_stop_epoch=sort_stop_epoch, progressbar=True)

    print('=> Loading mini-batch...')
    is_character = label_type in ['character', 'character_capital_divide']
    mapper_cls = Idx2char if is_character else Idx2phone
    map_fn = mapper_cls(map_file_path=map_file_path)

    for data, is_new_epoch in dataset:
        inputs, labels, inputs_seq_len, labels_seq_len, input_names = data

        # Test labels are stored as raw transcripts, others as indices
        if data_type == 'test':
            str_true = labels[0][0][0]
        else:
            str_true = map_fn(labels[0][0][:labels_seq_len[0][0]])

        print('----- %s ----- (epoch: %.3f)' %
              (input_names[0][0], dataset.epoch_detail))
        print(inputs[0][0].shape)
        print(str_true)
def do_eval_cer(save_paths, dataset, data_type, label_type, num_classes,
                beam_width, temperature_infer, is_test=False,
                progressbar=False):
    """Evaluate an ensemble of trained models by Character Error Rate.

    Per-utterance posteriors saved by each model are loaded from disk,
    averaged over the models and decoded with a beam search decoder.

    Args:
        save_paths (list): one directory per model; each is expected to hold
            `temp<temperature_infer>/<data_type>/probs_utt/<speaker>/<utt>.npy`
        dataset: An instance of a `Dataset` class
        data_type (string): part of the path of the saved posteriors
        label_type (string): character
        num_classes (int): the blank index is taken to be `num_classes - 1`
        beam_width (int): the size of beam
        temperature_infer (int): temperature in the inference stage
        is_test (bool, optional): set to True when evaluating by the test set
        progressbar (bool, optional): if True, visualize the progressbar
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        TypeError: if `label_type` is not 'character'
    """
    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
        char2idx = Char2idx(
            map_file_path='../metrics/mapping_files/character.txt')
    else:
        raise TypeError

    # Define decoder
    decoder = BeamSearchDecoder(space_index=char2idx(str_char='_')[0],
                                blank_index=num_classes - 1)

    ##################################################
    # Compute mean probabilities
    ##################################################
    if progressbar:
        pbar = tqdm(total=len(dataset))
    cer_mean, wer_mean = 0, 0
    for data, is_new_epoch in dataset:
        inputs, labels_true, inputs_seq_len, input_names = data

        batch_size = inputs[0].shape[0]
        for i_batch in range(batch_size):
            utt_name = input_names[0][i_batch]
            speaker = utt_name.split('-')[0]

            probs_ensemble = None
            for model_path in save_paths:
                # Load posteriors
                prob_save_path = join(
                    model_path, 'temp' + str(temperature_infer),
                    data_type, 'probs_utt', speaker, utt_name + '.npy')
                probs_model_i = np.load(prob_save_path)
                # NOTE: probs_model_i: `[T, num_classes]`

                # Sum over probs
                if probs_ensemble is None:
                    probs_ensemble = probs_model_i
                else:
                    probs_ensemble += probs_model_i

            # Compute mean posteriors
            probs_ensemble /= len(save_paths)

            # Decode per utterance
            labels_pred, _ = decoder(
                probs=probs_ensemble[np.newaxis, :, :],
                seq_len=inputs_seq_len[0][i_batch:i_batch + 1],
                beam_width=beam_width)

            # Convert from list of index to string
            if is_test:
                str_true = labels_true[0][i_batch][0]
                # NOTE: transcript is separated by space('_')
            else:
                str_true = idx2char(labels_true[0][i_batch],
                                    padded_value=dataset.padded_value)
            str_pred = idx2char(labels_pred[0])

            # Remove consecutive spaces
            str_pred = re.sub(r'[_]+', '_', str_pred)

            # Remove garbage labels
            str_true = re.sub(r'[\']+', '', str_true)
            str_pred = re.sub(r'[\']+', '', str_pred)

            # Compute WER
            # The reference is the true transcript and the hypothesis is the
            # prediction (they were swapped before, so the rate was
            # normalized by the hypothesis length).
            wer_mean += compute_wer(ref=str_true.split('_'),
                                    hyp=str_pred.split('_'),
                                    normalize=True)

            # Remove spaces
            str_true = re.sub(r'[_]+', '', str_true)
            str_pred = re.sub(r'[_]+', '', str_pred)

            # Compute CER
            cer_mean += compute_cer(str_pred=str_pred,
                                    str_true=str_true,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    cer_mean /= (len(dataset))
    wer_mean /= (len(dataset))
    # TODO: Fix this

    return cer_mean, wer_mean
def do_eval_cer(session, decode_ops, model, dataset, label_type,
                train_data_size, is_test=False, eval_batch_size=None,
                progressbar=False):
    """Evaluate trained model by Character Error Rate.

    Args:
        session: session of training model
        decode_ops (list): operations for decoding, one per device
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): a type containing 'kanji' or 'kana'
            (e.g. kanji, kana, kanji_divide, kana_divide)
        train_data_size (string): train_subset or train_fullset
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the
            model
        progressbar (bool, optional): if True, visualize the progressbar
    Return:
        cer_mean (float): An average of CER
    Raises:
        TypeError: if `label_type` contains neither 'kanji' nor 'kana'
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    # Remember the batch size so that it can be put back afterwards
    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Kanji vocabularies depend on the size of the training set
    mapping_dir = '../metrics/mapping_files/'
    if 'kanji' in label_type:
        map_file_path = mapping_dir + \
            label_type + '_' + train_data_size + '.txt'
    elif 'kana' in label_type:
        map_file_path = mapping_dir + label_type + '.txt'
    else:
        raise TypeError

    idx2char = Idx2char(map_file_path=map_file_path)

    garbage_pattern = r'[_NZー・<>]+'
    cer_mean = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:
        inputs, labels_true, inputs_seq_len, labels_seq_len, _ = data

        # Feed every device; dropout is disabled during evaluation
        feed_dict = {}
        for dev in range(len(decode_ops)):
            feed_dict[model.inputs_pl_list[dev]] = inputs[dev]
            feed_dict[model.inputs_seq_len_pl_list[dev]] = \
                inputs_seq_len[dev]
            feed_dict[model.keep_prob_encoder_pl_list[dev]] = 1.0
            feed_dict[model.keep_prob_decoder_pl_list[dev]] = 1.0
            feed_dict[model.keep_prob_embedding_pl_list[dev]] = 1.0

        labels_pred_list = session.run(decode_ops, feed_dict=feed_dict)

        for dev in range(len(labels_pred_list)):
            for utt in range(len(inputs[dev])):

                # Convert from list of index to string
                if is_test:
                    str_true = labels_true[dev][utt][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # The first and last labels are sliced off
                    last = labels_seq_len[dev][utt] - 1
                    str_true = idx2char(labels_true[dev][utt][1:last])

                # NOTE: Truncate by <EOS>
                decoded = idx2char(labels_pred_list[dev][utt])
                str_pred = decoded.split('>')[0]

                # Remove garbage labels
                str_true = re.sub(garbage_pattern, '', str_true)
                str_pred = re.sub(garbage_pattern, '', str_pred)

                # Compute CER
                cer_mean += compute_cer(str_pred=str_pred,
                                        str_true=str_true,
                                        normalize=True)

                if progressbar:
                    pbar.update(1)

        if is_new_epoch:
            break

    cer_mean /= len(dataset)

    # Register original batch size
    if eval_batch_size is not None:
        dataset.batch_size = batch_size_original

    return cer_mean
def eval_char(models, dataset, beam_width, max_decode_len,
              eval_batch_size=None, length_penalty=0, progressbar=False,
              temperature=1):
    """Evaluate trained model by Character Error Rate.

    Args:
        models (list): the models to evaluate (only the first one is used
            at the moment)
        dataset: An instance of a `Dataset' class
        beam_width: (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the
            model
        length_penalty (float, optional): forwarded to `model.decode`
        progressbar (bool, optional): if True, visualize the progressbar
        temperature (int, optional): currently not used by this function
    Returns:
        wer (float): Word error rate
        cer (float): Character error rate
        df_wer_cer (pd.DataFrame): dataframe of substitution, insertion, and
            deletion
    """
    import logging
    logger = logging.getLogger(__name__)

    # Reset data counter
    dataset.reset()

    # Single-task models use the main vocabulary, others the sub vocabulary
    if models[0].model_type in ['ctc', 'attention']:
        idx2char = Idx2char(
            vocab_file_path=dataset.vocab_file_path,
            capital_divide=(dataset.label_type == 'character_capital_divide'))
    else:
        idx2char = Idx2char(
            vocab_file_path=dataset.vocab_file_path_sub,
            capital_divide=(
                dataset.label_type_sub == 'character_capital_divide'))

    # Read GLM file
    glm = GLM(
        glm_path=
        '/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm'
    )

    wer, cer = 0, 0
    sub_word, ins_word, del_word = 0, 0, 0
    sub_char, ins_char, del_char = 0, 0, 0
    num_words, num_chars = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # TODO: add CTC ensemble

        # Decode
        model = models[0]
        # TODO: fix this

        if model.model_type in ['ctc', 'attention']:
            best_hyps, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                length_penalty=length_penalty)
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]
        else:
            best_hyps, _, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                length_penalty=length_penalty,
                task_index=1)
            ys = batch['ys_sub'][perm_idx]
            y_lens = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2char(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2char(best_hyps[b])
            if 'attention' in model.model_type:
                str_hyp = str_hyp.split('>')[0]
                # NOTE: Truncate by the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]

            ##############################
            # Post-processing
            ##############################
            str_ref = fix_trans(str_ref, glm)
            str_hyp = fix_trans(str_hyp, glm)

            # Utterances with an empty reference are not scored
            if len(str_ref) == 0:
                if progressbar:
                    pbar.update(1)
                continue

            try:
                # Compute WER
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=str_hyp.split('_'),
                    normalize=False)
                wer += wer_b
                sub_word += sub_b
                ins_word += ins_b
                del_word += del_b
                num_words += len(str_ref.split('_'))

                # Compute CER
                cer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=list(str_ref.replace('_', '')),
                    hyp=list(str_hyp.replace('_', '')),
                    normalize=False)
                cer += cer_b
                sub_char += sub_b
                ins_char += ins_b
                del_char += del_b
                num_chars += len(str_ref.replace('_', ''))
            except Exception:
                # Scoring stays best-effort (a failing utterance is skipped),
                # but the failure is reported, and KeyboardInterrupt /
                # SystemExit are no longer swallowed as with a bare except.
                logger.warning('Failed to score utterance %s; skipped.',
                               batch['input_names'][b]
                               if 'input_names' in batch else b,
                               exc_info=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    # Error counts were accumulated un-normalized; normalize them by the
    # total reference length here
    wer /= num_words
    sub_word /= num_words
    ins_word /= num_words
    del_word /= num_words
    cer /= num_chars
    sub_char /= num_chars
    ins_char /= num_chars
    del_char /= num_chars

    df_wer_cer = pd.DataFrame(
        {
            'SUB': [sub_word * 100, sub_char * 100],
            'INS': [ins_word * 100, ins_char * 100],
            'DEL': [del_word * 100, del_char * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['WER', 'CER'])

    return wer, cer, df_wer_cer
def check(self, label_type, data_type='dev', shuffle=False, sort_utt=False,
          sort_stop_epoch=None, frame_stacking=False, splice=1, num_gpu=1):
    """Load mini-batches and print the first utterance of each one.

    For the training set, every utterance of the first device is also
    checked: its input must not be shorter than its label, otherwise
    `ValueError` is raised. Iteration stops once `dataset.epoch_detail`
    reaches 0.1.

    Args:
        label_type (string): must contain 'kana' or 'kanji'; selects the
            mapping file
        data_type (string, optional): forwarded to `Dataset`; labels of
            types containing 'eval' are printed as raw transcripts
        shuffle (bool, optional): forwarded to `Dataset`
        sort_utt (bool, optional): forwarded to `Dataset`
        sort_stop_epoch (int, optional): forwarded to `Dataset`
        frame_stacking (bool, optional): if True, stack and skip 3 frames
        splice (int, optional): forwarded to `Dataset`
        num_gpu (int, optional): forwarded to `Dataset`; when larger than 1
            the input shape of every device is printed
    Raises:
        TypeError: if `label_type` contains neither 'kana' nor 'kanji'
        ValueError: if a training input is shorter than its label
    """
    print('========================================')
    print(' label_type: %s' % label_type)
    print(' data_type: %s' % data_type)
    print(' shuffle: %s' % str(shuffle))
    print(' sort_utt: %s' % str(sort_utt))
    print(' sort_stop_epoch: %s' % str(sort_stop_epoch))
    print(' frame_stacking: %s' % str(frame_stacking))
    print(' splice: %d' % splice)
    print(' num_gpu: %d' % num_gpu)
    print('========================================')

    if 'kana' in label_type:
        map_file_path = '../../metrics/mapping_files/' + label_type + '.txt'
    elif 'kanji' in label_type:
        map_file_path = '../../metrics/mapping_files/' + \
            label_type + '_train_subset.txt'
    else:
        # Fail with a clear error instead of leaving map_file_path unbound
        raise TypeError(
            'label_type must contain "kana" or "kanji": %s' % label_type)

    num_stack = 3 if frame_stacking else 1
    num_skip = 3 if frame_stacking else 1
    dataset = Dataset(data_type=data_type, train_data_size='train_subset',
                      label_type=label_type, map_file_path=map_file_path,
                      batch_size=64, max_epoch=2, splice=splice,
                      num_stack=num_stack, num_skip=num_skip,
                      shuffle=shuffle, sort_utt=sort_utt,
                      sort_stop_epoch=sort_stop_epoch,
                      progressbar=True, num_gpu=num_gpu)

    print('=> Loading mini-batch...')
    idx2char = Idx2char(map_file_path)

    for data, is_new_epoch in dataset:
        inputs, labels, inputs_seq_len, labels_seq_len, input_names = data

        if data_type == 'train':
            for frames, label in zip(inputs[0], labels[0]):
                if len(frames) < len(label):
                    raise ValueError(
                        'input length must be longer than label length.')

        if num_gpu > 1:
            for inputs_gpu in inputs:
                print(inputs_gpu.shape)

        # Evaluation labels are stored as raw transcripts, others as indices
        if 'eval' in data_type:
            str_true = labels[0][0][0]
        else:
            str_true = idx2char(labels[0][0][0:labels_seq_len[0][0]])

        print('----- %s (epoch: %.3f) -----' %
              (input_names[0][0], dataset.epoch_detail))
        print(inputs[0].shape)
        print(str_true)

        if dataset.epoch_detail >= 0.1:
            break
def decode(model, dataset, beam_width, max_decode_len, max_decode_len_sub,
           eval_batch_size=None, save_path=None):
    """Visualize label outputs.

    Prints, for every utterance, the reference, the main-task (word) and
    sub-task (character) hypotheses, and the per-utterance WER and CER.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width: (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        max_decode_len_sub (int): same as `max_decode_len`, for the sub task
        eval_batch_size (int, optional): the batch size when evaluating the
            model
        save_path (string): if not None, the output is written to
            `decode.txt` under `model.model_dir` instead of the console
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)
    if dataset.label_type_sub == 'character':
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub)
    elif dataset.label_type_sub == 'character_capital_divide':
        idx2char = Idx2char(vocab_file_path=dataset.vocab_file_path_sub,
                            capital_divide=True)

    if save_path is not None:
        # NOTE(review): stdout stays redirected after this function returns
        # and the file goes to model.model_dir, not save_path — confirm
        # whether that is intended.
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:

        # Decode
        if model.model_type == 'charseq_attention':
            # Use the caller's limit for the sub task
            # (it was hard-coded to 100 here before)
            best_hyps, best_hyps_sub, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len,
                max_decode_len_sub=max_decode_len_sub)
        else:
            best_hyps, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len)
            best_hyps_sub, perm_idx = model.decode(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                max_decode_len=max_decode_len_sub,
                is_sub_task=True)

        # References are re-ordered with the permutation returned by the model
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                str_ref_sub = ys_sub[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2word(ys[b][:y_lens[b]])
                # Sub-task labels are characters, so they need the character
                # mapper (the word mapper was used here before)
                str_ref_sub = idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = idx2word(best_hyps[b])
            str_hyp_sub = idx2char(best_hyps_sub[b])

            if model.model_type != 'hierarchical_ctc':
                str_hyp = str_hyp.split('>')[0]
                str_hyp_sub = str_hyp_sub.split('>')[0]
                # NOTE: Truncate by the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]
                if len(str_hyp_sub) > 0 and str_hyp_sub[-1] == '_':
                    str_hyp_sub = str_hyp_sub[:-1]

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[\'>]+', '', str_ref)
            str_hyp = re.sub(r'[\'>]+', '', str_hyp)

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref: %s' % str_ref.replace('_', ' '))
            print('Hyp (main): %s' % str_hyp.replace('_', ' '))
            print('Hyp (sub): %s' % str_hyp_sub.replace('_', ' '))

            wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                       hyp=str_hyp.split('_'),
                                       normalize=True)
            print('WER: %f %%' % (wer * 100))
            cer, _, _, _ = compute_wer(
                ref=list(str_ref_sub.replace('_', '')),
                hyp=list(str_hyp_sub.replace('_', '')),
                normalize=True)
            print('CER: %f %%' % (cer * 100))

        if is_new_epoch:
            break
def decode(session, decode_op, model, dataset, label_type, is_test=True,
           save_path=None):
    """Visualize label outputs of CTC model.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61 or character or
            character_capital_divide
        is_test (bool, optional): if True, the references are read as raw
            transcripts instead of index sequences
        save_path (string, optional): if not None, the output is written to
            `decode.txt` under `model.model_dir` instead of the console
    """
    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path=
            '../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True)
    else:
        map_fn = Idx2phone(map_file_path='../metrics/mapping_files/' +
                           label_type + '.txt')

    if save_path is not None:
        # NOTE(review): stdout stays redirected after this function returns
        # and the file goes to model.model_dir, not save_path — confirm
        # whether that is intended.
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini batch
        inputs, labels_true, inputs_seq_len, input_names = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_pl_list[0]: 1.0
        }

        batch_size = inputs[0].shape[0]
        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        try:
            labels_pred = sparsetensor2list(labels_pred_st,
                                            batch_size=batch_size)
        except IndexError:
            # no output: every hypothesis of this batch is empty. The old
            # one-element fallback list failed with IndexError for any
            # utterance after the first.
            labels_pred = None

        for i_batch in range(batch_size):
            print('----- wav: %s -----' % input_names[0][i_batch])

            # The character and phone branches were identical, so a single
            # path is used for both
            if is_test:
                str_true = labels_true[0][i_batch][0]
            else:
                str_true = map_fn(labels_true[0][i_batch])

            if labels_pred is None:
                str_pred = ''
            else:
                str_pred = map_fn(labels_pred[i_batch])

            print('Ref: %s' % str_true)
            print('Hyp: %s' % str_pred)

        if is_new_epoch:
            break
def do_eval_cer(session, decode_op, model, dataset, label_type,
                eval_batch_size=None, progressbar=False, is_multitask=False):
    """Evaluate trained model by Character Error Rate.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        eval_batch_size (int, optional): the batch size when evaluating the
            model (not referenced in this function)
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Return:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    """
    # Reset data counter
    dataset.reset()

    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/ctc/character.txt')
    elif label_type == 'character_capital_divide':
        idx2char = Idx2char(
            map_file_path=
            '../metrics/mapping_files/ctc/character_capital_divide.txt',
            capital_divide=True, space_mark='_')

    garbage_pattern = r'[\'\":;!?,.-]+'
    space_pattern = r'[_]+'

    cer_mean, wer_mean = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:

        # Multitask batches carry one extra label field, which is skipped
        if is_multitask:
            inputs, labels_true, _, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        # Dropout is disabled during evaluation
        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        batch_size_each = len(inputs)

        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        labels_pred = sparsetensor2list(labels_pred_st, batch_size_each)

        for utt in range(batch_size_each):

            # Convert from list of index to string
            str_true = idx2char(labels_true[utt])
            str_pred = idx2char(labels_pred[utt])

            # Remove consecutive spaces
            str_pred = re.sub(space_pattern, '_', str_pred)

            # Remove garbage labels
            str_true = re.sub(garbage_pattern, "", str_true)
            str_pred = re.sub(garbage_pattern, "", str_pred)

            # Compute WER
            words_pred = str_pred.split('_')
            words_true = str_true.split('_')
            wer_mean += compute_wer(hyp=words_pred,
                                    ref=words_true,
                                    normalize=True)

            # Remove spaces
            str_pred = re.sub(space_pattern, "", str_pred)
            str_true = re.sub(space_pattern, "", str_true)

            # Compute CER
            cer_mean += compute_cer(str_pred=str_pred,
                                    str_true=str_true,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    cer_mean /= len(dataset)
    wer_mean /= len(dataset)

    return cer_mean, wer_mean
def decode(session, decode_op, model, dataset, label_type, train_data_size,
           is_test=True, save_path=None):
    """Visualize label outputs of CTC model.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): kanji or kanji_divide or kana or kana_divide
        train_data_size (string): train_subset or train_fullset
        is_test (bool, optional): set to True when evaluating by the test set
        save_path (string, optional): if not None, the decoding results are
            written to `decode.txt` under `model.model_dir` instead of the
            standard output
    Raises:
        TypeError: if `label_type` contains neither 'kanji' nor 'kana'
    """
    if 'kanji' in label_type:
        map_file_path = '../metrics/mapping_files/' + \
            label_type + '_' + train_data_size + '.txt'
    elif 'kana' in label_type:
        map_file_path = '../metrics/mapping_files/' + label_type + '.txt'
    else:
        raise TypeError

    idx2char = Idx2char(map_file_path=map_file_path)

    # Redirect the standard output only for the duration of this call so
    # that the file is closed and later prints are not captured.
    stdout_original = sys.stdout
    out_file = None
    if save_path is not None:
        out_file = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = out_file

    try:
        for data, is_new_epoch in dataset:
            # Create feed dictionary for next mini batch
            inputs, labels_true, inputs_seq_len, input_names = data
            feed_dict = {
                model.inputs_pl_list[0]: inputs[0],
                model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
                model.keep_prob_pl_list[0]: 1.0
            }

            # Decode
            batch_size = inputs[0].shape[0]
            labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
            try:
                labels_pred = sparsetensor2list(labels_pred_st,
                                                batch_size=batch_size)
            except IndexError:
                # no output: every hypothesis in this batch is empty
                labels_pred = None

            # Visualize
            for i_batch in range(batch_size):
                print('----- wav: %s -----' % input_names[0][i_batch])
                if is_test:
                    # The transcript itself is stored in the test set
                    str_true = labels_true[0][i_batch][0]
                else:
                    str_true = idx2char(labels_true[0][i_batch])
                if labels_pred is None:
                    str_pred = ''
                else:
                    str_pred = idx2char(labels_pred[i_batch])

                print('Ref: %s' % str_true)
                print('Hyp: %s' % str_pred)

            if is_new_epoch:
                break
    finally:
        if out_file is not None:
            sys.stdout = stdout_original
            out_file.close()
def __init__(self, data_save_path,
             backend, input_freq, use_delta, use_double_delta,
             data_type, data_size, label_type, label_type_sub,
             batch_size, max_epoch=None, splice=1,
             num_stack=1, num_skip=1, min_frame_num=40,
             shuffle=False, sort_utt=False, reverse=False,
             sort_stop_epoch=None, num_gpus=1, tool='htk',
             num_enque=None, dynamic_batching=False):
    """A class for loading dataset.

    Args:
        data_save_path (string): path to saved data
        backend (string): pytorch or chainer
        input_freq (int): the number of dimensions of acoustics
        use_delta (bool): if True, use the delta feature
        use_double_delta (bool): if True, use the acceleration feature
        data_type (string): train or dev_clean or dev_other or test_clean
            or test_other
        data_size (string): 100 or 460 or 960
        label_type (string): word
        label_type_sub (string): character or character_capital_divide
        batch_size (int): the size of mini-batch (per GPU; it is multiplied
            by `num_gpus`)
        max_epoch (int): the max epoch. None means infinite loop.
        splice (int): frames to splice. Default is 1 frame.
        num_stack (int): the number of frames to stack
        num_skip (int): the number of frames to skip
        min_frame_num (int): utterances with fewer frames than this are
            removed unless `data_type` contains 'test'
        shuffle (bool): if True, shuffle utterances.
            This is disabled when sort_utt is True.
        sort_utt (bool): if True, sort all utterances in the ascending order
        reverse (bool): if True, sort utterances in the descending order
        sort_stop_epoch (int): After sort_stop_epoch, training will revert
            back to a random order
        num_gpus (int): the number of GPUs
        tool (string): htk or librosa or python_speech_features
        num_enque (int): the number of elements to enqueue
        dynamic_batching (bool): if True, batch size will be changed
            dynamically in training
    """
    self.backend = backend
    self.input_freq = input_freq
    self.use_delta = use_delta
    self.use_double_delta = use_double_delta
    self.data_type = data_type
    self.data_size = data_size
    self.label_type = label_type
    self.label_type_sub = label_type_sub
    self.batch_size = batch_size * num_gpus
    self.max_epoch = max_epoch
    self.splice = splice
    self.num_stack = num_stack
    self.num_skip = num_skip
    self.shuffle = shuffle
    self.sort_utt = sort_utt
    self.sort_stop_epoch = sort_stop_epoch
    self.num_gpus = num_gpus
    self.tool = tool
    self.num_enque = num_enque
    self.dynamic_batching = dynamic_batching
    self.is_test = 'test' in data_type

    # Vocabulary and index <-> label converters (main task)
    self.vocab_file_path = join(
        data_save_path, 'vocab', data_size, label_type + '.txt')
    self.idx2word = Idx2word(self.vocab_file_path)
    self.word2idx = Word2idx(self.vocab_file_path)

    # Vocabulary and index <-> label converters (sub task)
    capital_divide = label_type_sub == 'character_capital_divide'
    self.vocab_file_path_sub = join(
        data_save_path, 'vocab', data_size, label_type_sub + '.txt')
    self.idx2char = Idx2char(
        self.vocab_file_path_sub, capital_divide=capital_divide)
    self.char2idx = Char2idx(
        self.vocab_file_path_sub, capital_divide=capital_divide)

    super(Dataset, self).__init__(vocab_file_path=self.vocab_file_path,
                                  vocab_file_path_sub=self.vocab_file_path_sub)

    # Load dataset file
    columns = ['frame_num', 'input_path', 'transcript']
    dataset_path = join(
        data_save_path, 'dataset', tool, data_size, data_type,
        label_type + '.csv')
    dataset_path_sub = join(
        data_save_path, 'dataset', tool, data_size, data_type,
        label_type_sub + '.csv')
    df = pd.read_csv(dataset_path)
    df = df.loc[:, columns]
    df_sub = pd.read_csv(dataset_path_sub)
    df_sub = df_sub.loc[:, columns]

    # Remove inappropriate utterances
    if not self.is_test:
        logger.info('Original utterance num: %d' % len(df))
        # The same restriction must be applied to both tasks; filtering
        # only the main task would break the length check below.
        # NOTE(review): assumes both CSV files list the same utterances
        # with the same `frame_num` -- confirm against the data preparation.
        df = df[df['frame_num'] >= min_frame_num]
        df_sub = df_sub[df_sub['frame_num'] >= min_frame_num]
        logger.info('Restricted utterance num: %d' % len(df))

    # Sort paths to input & label
    if sort_utt:
        df = df.sort_values(by='frame_num', ascending=not reverse)
        df_sub = df_sub.sort_values(by='frame_num', ascending=not reverse)
    else:
        df = df.sort_values(by='input_path', ascending=True)
        df_sub = df_sub.sort_values(by='input_path', ascending=True)

    assert len(df) == len(df_sub)

    self.df = df
    self.df_sub = df_sub
    # Indices of the utterances which are not yet consumed
    self.rest = set(df.index)
def decode(session, decode_op_main, decode_op_sub, model, dataset,
           label_type_main, label_type_sub, is_test=True, save_path=None):
    """Visualize label outputs of Multi-task CTC model.

    Args:
        session: session of training model
        decode_op_main: operation for decoding in the main task
        decode_op_sub: operation for decoding in the sub task
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type_main (string): character or character_capital_divide
        label_type_sub (string): phone39 or phone48 or phone61
        is_test (bool, optional): if True, the references are read directly
            from the transcripts stored in the dataset
        save_path (string, optional): if not None, the decoding results are
            written to `decode.txt` under `model.model_dir` instead of the
            standard output
    """
    idx2char = Idx2char(map_file_path='../metrics/mapping_files/' +
                        label_type_main + '.txt')
    idx2phone = Idx2phone(map_file_path='../metrics/mapping_files/' +
                          label_type_sub + '.txt')

    # Redirect the standard output only for the duration of this call so
    # that the file is closed and later prints are not captured.
    stdout_original = sys.stdout
    out_file = None
    if save_path is not None:
        out_file = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = out_file

    try:
        for data, is_new_epoch in dataset:
            # Create feed dictionary for next mini batch
            (inputs, labels_true_char, labels_true_phone,
             inputs_seq_len, input_names) = data
            feed_dict = {
                model.inputs_pl_list[0]: inputs[0],
                model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
                model.keep_prob_pl_list[0]: 1.0
            }

            batch_size = inputs[0].shape[0]
            labels_pred_char_st, labels_pred_phone_st = session.run(
                [decode_op_main, decode_op_sub], feed_dict=feed_dict)
            try:
                labels_pred_char = sparsetensor2list(labels_pred_char_st,
                                                     batch_size=batch_size)
            except IndexError:
                # no output: every hypothesis in this batch is empty
                labels_pred_char = None
            try:
                # The sub task must be converted from its own sparse tensor
                labels_pred_phone = sparsetensor2list(labels_pred_phone_st,
                                                      batch_size=batch_size)
            except IndexError:
                # no output: every hypothesis in this batch is empty
                labels_pred_phone = None

            for i_batch in range(batch_size):
                print('----- wav: %s -----' % input_names[0][i_batch])

                if is_test:
                    str_true_char = labels_true_char[0][i_batch][0].replace(
                        '_', ' ')
                    str_true_phone = labels_true_phone[0][i_batch][0]
                else:
                    str_true_char = idx2char(labels_true_char[0][i_batch])
                    str_true_phone = idx2phone(labels_true_phone[0][i_batch])

                if labels_pred_char is None:
                    str_pred_char = ''
                else:
                    str_pred_char = idx2char(labels_pred_char[i_batch])
                if labels_pred_phone is None:
                    str_pred_phone = ''
                else:
                    str_pred_phone = idx2phone(labels_pred_phone[i_batch])

                print('Ref (char): %s' % str_true_char)
                print('Hyp (char): %s' % str_pred_char)
                print('Ref (phone): %s' % str_true_phone)
                print('Hyp (phone): %s' % str_pred_phone)

            if is_new_epoch:
                break
    finally:
        if out_file is not None:
            sys.stdout = stdout_original
            out_file.close()
def decode(session, decode_op, model, dataset, label_type,
           is_test=False, save_path=None):
    """Visualize label outputs of Attention-based model.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61 or character or
            character_capital_divide
        is_test (bool, optional): if True, the references are read directly
            from the transcripts stored in the dataset
        save_path (string): if not None, the decoding results are written to
            `decode.txt` under `model.model_dir` instead of the standard
            output
    """
    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True)
    else:
        map_fn = Idx2phone(map_file_path='../metrics/mapping_files/' +
                           label_type + '.txt')

    # Redirect the standard output only for the duration of this call so
    # that the file is closed and later prints are not captured.
    stdout_original = sys.stdout
    out_file = None
    if save_path is not None:
        out_file = open(join(model.model_dir, 'decode.txt'), 'w')
        sys.stdout = out_file

    try:
        for data, is_new_epoch in dataset:
            # Create feed dictionary for next mini batch
            (inputs, labels_true, inputs_seq_len,
             labels_seq_len, input_names) = data
            feed_dict = {
                model.inputs_pl_list[0]: inputs[0],
                model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
                model.keep_prob_encoder_pl_list[0]: 1.0,
                model.keep_prob_decoder_pl_list[0]: 1.0,
                model.keep_prob_embedding_pl_list[0]: 1.0
            }

            batch_size = inputs[0].shape[0]
            labels_pred = session.run(decode_op, feed_dict=feed_dict)
            for i_batch in range(batch_size):
                print('----- wav: %s -----' % input_names[0][i_batch])
                if is_test:
                    str_true = labels_true[0][i_batch][0]
                else:
                    str_true = map_fn(
                        labels_true[0][i_batch][1:labels_seq_len[0][i_batch] - 1])
                    # NOTE: Exclude <SOS> and <EOS>

                str_pred = map_fn(labels_pred[i_batch]).split('>')[0]
                # NOTE: Truncate by <EOS>

                # Remove the last space. `endswith` is also safe when the
                # hypothesis is empty, unlike indexing the last character.
                if 'phone' in label_type and str_pred.endswith(' '):
                    str_pred = str_pred[:-1]

                print('Ref: %s' % str_true)
                print('Hyp: %s' % str_pred)

            if is_new_epoch:
                break
    finally:
        if out_file is not None:
            sys.stdout = stdout_original
            out_file.close()
def plot(model, dataset, beam_width, beam_width_sub,
         eval_batch_size=None, a2c_oracle=False, save_path=None):
    """Visualize attention weights of Attention-based model.

    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width: (int): the size of beam in the main task
        beam_width_sub: (int): the size of beam in the sub task
        eval_batch_size (int, optional): the batch size when evaluating the
            model. When given, it is assigned to `dataset.batch_size` during
            plotting and the previous value is restored afterwards.
        a2c_oracle (bool, optional): if True, the reference labels of the
            sub task are passed to the model with teacher forcing
        save_path (string, optional): path to save attention weights plotting
    """
    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    idx2word = Idx2word(dataset.vocab_file_path, return_list=True)
    idx2char = Idx2char(dataset.vocab_file_path_sub, return_list=True)

    # The oracle on the test set converts raw transcripts into indices.
    # NOTE(review): this block previously called an undefined `char2idx`;
    # confirm that `Char2idx` is imported by this module.
    char2idx = None
    if a2c_oracle and dataset.is_test:
        char2idx = Char2idx(dataset.vocab_file_path_sub)

    # Set batch size in the evaluation
    batch_size_original = None
    if eval_batch_size is not None:
        batch_size_original = dataset.batch_size
        dataset.batch_size = eval_batch_size

    try:
        for batch, is_new_epoch in dataset:
            batch_size = len(batch['xs'])

            ys_sub = None
            y_lens_sub = None
            if a2c_oracle:
                if dataset.is_test:
                    # NOTE: transcript is separated by space('_')
                    transcripts = [batch['ys_sub'][b][0]
                                   for b in range(batch_size)]
                    max_label_num = max([len(list(t)) for t in transcripts] + [0])

                    # pad with -1
                    ys_sub = np.full((batch_size, max_label_num), -1,
                                     dtype=np.int32)
                    y_lens_sub = np.zeros((batch_size,), dtype=np.int32)
                    for b, transcript in enumerate(transcripts):
                        indices = char2idx(transcript)
                        ys_sub[b, :len(indices)] = indices
                        y_lens_sub[b] = len(indices)
                else:
                    ys_sub = batch['ys_sub']
                    y_lens_sub = batch['y_lens_sub']

            best_hyps, best_hyps_sub, aw, aw_sub, aw_dec = model.attention_weights(
                batch['xs'], batch['x_lens'],
                beam_width=beam_width,
                beam_width_sub=beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_WORD,
                max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                teacher_forcing=a2c_oracle,
                ys_sub=ys_sub,
                y_lens_sub=y_lens_sub)

            for b in range(batch_size):
                word_list = idx2word(best_hyps[b])
                if 'word' in dataset.label_type_sub:
                    char_list = idx2word(best_hyps_sub[b])
                else:
                    char_list = idx2char(best_hyps_sub[b])

                x_len = batch['x_lens'][b]
                name = batch['input_names'][b]

                # word to acoustic & character to acoustic
                plot_hierarchical_attention_weights(
                    aw[b][:len(word_list), :x_len],
                    aw_sub[b][:len(char_list), :x_len],
                    label_list=word_list,
                    label_list_sub=char_list,
                    spectrogram=batch['xs'][b, :, :dataset.input_freq],
                    save_path=mkdir_join(save_path, name + '.png'),
                    figsize=(40, 8))

                # word to character
                plot_word2char_attention_weights(
                    aw_dec[b][:len(word_list), :len(char_list)],
                    label_list=word_list,
                    label_list_sub=char_list,
                    save_path=mkdir_join(save_path, name + '_word2char.png'),
                    figsize=(40, 8))

            if is_new_epoch:
                break
    finally:
        # Register original batch size
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original
def check(self, label_type, label_type_sub, data_type='dev_clean',
          data_size='100h', backend='pytorch',
          shuffle=False, sort_utt=True, sort_stop_epoch=None,
          frame_stacking=False, splice=1, num_gpus=1):
    """Load mini-batches with the given configuration and print the first
    utterance of each one.

    The configuration is printed first, then a `Dataset` is built and
    iterated until `dataset.epoch_detail` reaches 0.01.

    Raises:
        ValueError: for `train` data with the pytorch backend, if the second
            dimension of `batch['xs']` is smaller than that of `batch['ys']`
    """
    separator = '========================================'
    settings = [
        ' backend: %s' % backend,
        ' label_type: %s' % label_type,
        ' label_type_sub: %s' % label_type_sub,
        ' data_type: %s' % data_type,
        ' data_size: %s' % data_size,
        ' shuffle: %s' % str(shuffle),
        ' sort_utt: %s' % str(sort_utt),
        ' sort_stop_epoch: %s' % str(sort_stop_epoch),
        ' frame_stacking: %s' % str(frame_stacking),
        ' splice: %d' % splice,
        ' num_gpus: %d' % num_gpus,
    ]
    print(separator)
    for setting in settings:
        print(setting)
    print(separator)

    vocab_dir = '../../metrics/vocab_files/'
    vocab_file_path = vocab_dir + label_type + '_' + data_size + '.txt'
    vocab_file_path_sub = vocab_dir + \
        label_type_sub + '_' + data_size + '.txt'

    # Stacking and skipping use the same factor
    num_stack = num_skip = 3 if frame_stacking else 1

    dataset = Dataset(backend=backend,
                      input_channel=40,
                      use_delta=True,
                      use_double_delta=True,
                      data_type=data_type,
                      data_size=data_size,
                      label_type=label_type,
                      label_type_sub=label_type_sub,
                      batch_size=64,
                      vocab_file_path=vocab_file_path,
                      vocab_file_path_sub=vocab_file_path_sub,
                      max_epoch=1,
                      splice=splice,
                      num_stack=num_stack,
                      num_skip=num_skip,
                      shuffle=shuffle,
                      sort_utt=sort_utt,
                      reverse=True,
                      sort_stop_epoch=sort_stop_epoch,
                      num_gpus=num_gpus,
                      save_format='numpy',
                      num_enque=None)

    print('=> Loading mini-batch...')

    idx2word = Idx2word(vocab_file_path, space_mark=' ')
    idx2char = Idx2char(vocab_file_path_sub)

    check_length = data_type == 'train' and backend == 'pytorch'
    for batch, _ in dataset:
        if check_length:
            # The comparison is repeated once per utterance in the batch
            for _i in range(len(batch['xs'])):
                if batch['xs'].shape[1] < batch['ys'].shape[1]:
                    raise ValueError(
                        'input length must be longer than label length.')

        # Only the first utterance of each mini-batch is displayed
        if not dataset.is_test:
            str_ref = idx2word(batch['ys'][0][:batch['y_lens'][0]])
            str_ref_sub = idx2char(
                batch['ys_sub'][0][:batch['y_lens_sub'][0]])
        else:
            # The transcripts themselves are stored in the test set
            str_ref = batch['ys'][0][0]
            str_ref_sub = batch['ys_sub'][0][0]

        print('----- %s (epoch: %.3f, batch: %d) -----' %
              (batch['input_names'][0], dataset.epoch_detail, len(batch['xs'])))
        print('=' * 20)
        print(str_ref)
        print('-' * 10)
        print(str_ref_sub)
        print('x_lens: %d' % (batch['x_lens'][0] * num_stack))
        if not dataset.is_test:
            print('y_lens (word): %d' % batch['y_lens'][0])
            print('y_lens_sub (char): %d' % batch['y_lens_sub'][0])

        if dataset.epoch_detail >= 0.01:
            break
def eval_word(models, dataset, beam_width, max_decode_len,
              beam_width_sub=1, max_decode_len_sub=300,
              eval_batch_size=None, length_penalty=0,
              progressbar=False, temperature=1,
              resolving_unk=False, a2c_oracle=False):
    """Evaluate trained model by Word Error Rate.

    Args:
        models (list): the models to evaluate. More than one model means
            ensemble decoding, which is restricted to CTC models.
        dataset: An instance of a `Dataset' class
        beam_width (int): the size of beam in the main task
        max_decode_len (int): the length of output sequences
            to stop prediction. This is used for seq2seq models.
        beam_width_sub (int, optional): the size of beam in the sub task.
            This is used for the nested attention
        max_decode_len_sub (int, optional): the length of output sequences
            to stop prediction. This is used for the nested attention
        eval_batch_size (int, optional): the batch size when evaluating the
            model
        length_penalty (float, optional): passed to `model.decode`
        progressbar (bool, optional): if True, visualize the progressbar
        temperature (int, optional): not used by this function
        resolving_unk (bool, optional): if True, OOV words in the hypotheses
            are replaced by using the hypotheses of the sub task
        a2c_oracle (bool, optional): if True, the reference labels of the
            sub task are passed to the nested attention model with teacher
            forcing
    Returns:
        wer (float): Word error rate
        df_wer (pd.DataFrame): dataframe of substitution, insertion, and
            deletion
    """
    import logging
    logger = logging.getLogger(__name__)

    # Reset data counter
    dataset.reset()

    idx2word = Idx2word(dataset.vocab_file_path)
    if models[0].model_type == 'nested_attention':
        char2idx = Char2idx(dataset.vocab_file_path_sub)
    # The converter is required whenever UNK is resolved. The previous
    # condition compared the model object itself with strings, so the
    # converter was never created.
    if resolving_unk:
        idx2char = Idx2char(dataset.vocab_file_path_sub,
                            capital_divide=dataset.label_type_sub == 'character_capital_divide')

    wer = 0
    sub, ins, dele, = 0, 0, 0
    num_words = 0
    pbar = tqdm(total=len(dataset)) if progressbar else None  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        batch_size = len(batch['xs'])

        # Decode
        if len(models) > 1:
            # Ensemble: average the posteriors of all models
            assert models[0].model_type in ['ctc']
            for i, model in enumerate(models):
                probs, x_lens, perm_idx = model.posteriors(
                    batch['xs'], batch['x_lens'])
                if i == 0:
                    probs_ensenmble = probs
                else:
                    probs_ensenmble += probs
            probs_ensenmble /= len(models)

            best_hyps = models[0].decode_from_probs(probs_ensenmble, x_lens,
                                                    beam_width=1)
        else:
            model = models[0]
            # TODO: fix this

            if model.model_type == 'nested_attention':
                # Without the oracle no reference is given to the decoder
                ys_sub = None
                y_lens_sub = None
                max_decode_len_sub_batch = max_decode_len_sub
                if a2c_oracle:
                    if dataset.is_test:
                        # NOTE: transcript is separated by space('_')
                        transcripts = [batch['ys_sub'][b][0]
                                       for b in range(batch_size)]
                        max_label_num = max(
                            [len(list(t)) for t in transcripts] + [0])

                        # pad with -1
                        ys_sub = np.full((batch_size, max_label_num), -1,
                                         dtype=np.int32)
                        y_lens_sub = np.zeros((batch_size,), dtype=np.int32)
                        for b, transcript in enumerate(transcripts):
                            indices = char2idx(transcript)
                            ys_sub[b, :len(indices)] = indices
                            y_lens_sub[b] = len(indices)
                    else:
                        ys_sub = batch['ys_sub']
                        y_lens_sub = batch['y_lens_sub']
                        # NOTE(review): the longest reference is used as the
                        # decoding length here; confirm that `y_lens_sub`
                        # holds the label lengths expected by the model
                        max_label_num = int(max(y_lens_sub))
                    max_decode_len_sub_batch = max_label_num

                best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    beam_width_sub=beam_width_sub,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_decode_len_sub_batch,
                    length_penalty=length_penalty,
                    teacher_forcing=a2c_oracle,
                    ys_sub=ys_sub,
                    y_lens_sub=y_lens_sub)
            else:
                best_hyps, aw, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    length_penalty=length_penalty)
                if resolving_unk:
                    best_hyps_sub, aw_sub, _ = model.decode(
                        batch['xs'], batch['x_lens'],
                        beam_width=beam_width,
                        max_decode_len=max_decode_len_sub,
                        length_penalty=length_penalty,
                        task_index=1)

        # Reorder the references in the same way as the hypotheses
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(batch_size):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2word(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2word(best_hyps[b])
            if dataset.label_type == 'word':
                str_hyp = re.sub(r'(.*)_>(.*)', r'\1', str_hyp)
            else:
                str_hyp = re.sub(r'(.*)>(.*)', r'\1', str_hyp)
            # NOTE: Truncate by the first <EOS>

            ##############################
            # Resolving UNK
            ##############################
            if resolving_unk and 'OOV' in str_hyp:
                str_hyp = resolve_unk(
                    str_hyp, best_hyps_sub[b], aw[b], aw_sub[b], idx2char)
                str_hyp = str_hyp.replace('*', '')

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[@>]+', '', str_ref)
            str_hyp = re.sub(r'[@>]+', '', str_hyp)
            # NOTE: @ means noise

            # Remove consecutive spaces
            str_ref = re.sub(r'[_]+', '_', str_ref)
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            # Compute WER
            ref_words = str_ref.split('_')
            try:
                wer_b, sub_b, ins_b, del_b = compute_wer(
                    ref=ref_words,
                    hyp=str_hyp.split('_'),
                    normalize=False)
            except Exception:
                # Best effort: the utterance is excluded from the statistics,
                # but the failure is reported instead of silently ignored
                logger.warning('WER computation failed; utterance skipped',
                               exc_info=True)
            else:
                wer += wer_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_words += len(ref_words)

            if pbar is not None:
                pbar.update(1)

        if is_new_epoch:
            break

    if pbar is not None:
        pbar.close()

    # Reset data counters
    dataset.reset()

    wer /= num_words
    sub /= num_words
    ins /= num_words
    dele /= num_words

    df_wer = pd.DataFrame(
        {'SUB': [sub * 100], 'INS': [ins * 100], 'DEL': [dele * 100]},
        columns=['SUB', 'INS', 'DEL'], index=['WER'])

    return wer, df_wer
def check(self, label_type, label_type_sub, data_type='dev',
          data_size='300h', backend='pytorch',
          shuffle=False, sort_utt=True, sort_stop_epoch=None,
          frame_stacking=False, splice=1, num_gpus=1):
    """Load mini-batches with the given configuration and print the first
    utterance of each one.

    The configuration is printed first, then a `Dataset` is built and
    iterated until `dataset.epoch_detail` reaches 1.

    Raises:
        ValueError: for `train` data with the pytorch backend, if the second
            dimension of `batch['xs']` is smaller than that of `batch['ys']`
    """
    separator = '========================================'
    settings = [
        ' backend: %s' % backend,
        ' label_type: %s' % label_type,
        ' label_type_sub: %s' % label_type_sub,
        ' data_type: %s' % data_type,
        ' data_size: %s' % data_size,
        ' shuffle: %s' % str(shuffle),
        ' sort_utt: %s' % str(sort_utt),
        ' sort_stop_epoch: %s' % str(sort_stop_epoch),
        ' frame_stacking: %s' % str(frame_stacking),
        ' splice: %d' % splice,
        ' num_gpus: %d' % num_gpus,
    ]
    print(separator)
    for setting in settings:
        print(setting)
    print(separator)

    # Stacking and skipping use the same factor
    num_stack = num_skip = 3 if frame_stacking else 1

    dataset = Dataset(
        data_save_path='/n/sd8/inaguma/corpus/swbd/kaldi',
        backend=backend,
        input_freq=40,
        use_delta=True,
        use_double_delta=True,
        data_type=data_type,
        data_size=data_size,
        label_type=label_type,
        label_type_sub=label_type_sub,
        batch_size=64,
        max_epoch=1,
        splice=splice,
        num_stack=num_stack,
        num_skip=num_skip,
        shuffle=shuffle,
        sort_utt=sort_utt,
        reverse=True,
        sort_stop_epoch=sort_stop_epoch,
        num_gpus=num_gpus,
        tool='htk',
        num_enque=None)

    print('=> Loading mini-batch...')

    idx2word = Idx2word(dataset.vocab_file_path)
    idx2char = Idx2char(dataset.vocab_file_path_sub)

    check_length = data_type == 'train' and backend == 'pytorch'
    for batch, _ in dataset:
        if check_length:
            # The comparison is repeated once per utterance in the batch
            for _i in range(len(batch['xs'])):
                if batch['xs'].shape[1] < batch['ys'].shape[1]:
                    raise ValueError(
                        'input length must be longer than label length.')

        # Only the first utterance of each mini-batch is displayed
        if not dataset.is_test:
            str_ref = idx2word(batch['ys'][0][:batch['y_lens'][0]])
            str_ref_sub = idx2char(
                batch['ys_sub'][0][:batch['y_lens_sub'][0]])
        else:
            # Raw transcripts: lowercase them and strip the parentheses
            str_ref = batch['ys'][0][0].lower()
            str_ref = str_ref.replace('(', '').replace(')', '')
            str_ref_sub = batch['ys_sub'][0][0].lower()
            str_ref_sub = str_ref_sub.replace('(', '').replace(')', '')

        print('----- %s (epoch: %.3f, batch: %d) -----' %
              (batch['input_names'][0], dataset.epoch_detail, len(batch['xs'])))
        print('=' * 20)
        print(str_ref)
        print('-' * 10)
        print(str_ref_sub)
        print('x_lens: %d' % (batch['x_lens'][0] * num_stack))
        if not dataset.is_test:
            print('y_lens (word): %d' % batch['y_lens'][0])
            print('y_lens_sub (char): %d' % batch['y_lens_sub'][0])

        if dataset.epoch_detail >= 1:
            break
def do_eval_cer(session, decode_ops, model, dataset, label_type,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Character Error Rate.

    Args:
        session: session of training model
        decode_ops: list of operations for decoding (one per device)
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): character or character_capital_divide
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the
            model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Returns:
        cer_mean (float): An average of CER
        wer_mean (float): An average of WER
    Raises:
        TypeError: if `label_type` is not one of the supported types
    """
    assert isinstance(decode_ops, list), "decode_ops must be a list."

    # Validate the label type before the dataset is modified
    if label_type == 'character':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        idx2char = Idx2char(
            map_file_path='../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True,
            space_mark='_')
    else:
        raise TypeError

    # Patterns are loop-invariant, so compile them once
    spaces = re.compile(r'[_]+')
    garbage_labels = re.compile(r'[\']+')

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    cer_mean, wer_mean = 0, 0
    skip_data_num = 0
    pbar = tqdm(total=len(dataset)) if progressbar else None
    try:
        for data, is_new_epoch in dataset:
            # Create feed dictionary for next mini batch
            if is_multitask:
                inputs, _, labels_true, inputs_seq_len, _ = data
            else:
                inputs, labels_true, inputs_seq_len, _ = data

            feed_dict = {}
            for i_device in range(len(decode_ops)):
                feed_dict[model.inputs_pl_list[i_device]] = inputs[i_device]
                feed_dict[model.inputs_seq_len_pl_list[i_device]] = \
                    inputs_seq_len[i_device]
                feed_dict[model.keep_prob_pl_list[i_device]] = 1.0

            labels_pred_st_list = session.run(decode_ops, feed_dict=feed_dict)
            for i_device, labels_pred_st in enumerate(labels_pred_st_list):
                batch_size_device = len(inputs[i_device])

                # Accumulate per device first: a device batch which fails
                # midway is counted as skipped as a whole, so it must not
                # leave partial sums in the totals.
                cer_device, wer_device = 0, 0
                try:
                    labels_pred = sparsetensor2list(labels_pred_st,
                                                    batch_size_device)
                    for i_batch in range(batch_size_device):
                        # Convert from list of index to string
                        if is_test:
                            str_true = labels_true[i_device][i_batch][0]
                            # NOTE: transcript is separated by space('_')
                        else:
                            str_true = idx2char(
                                labels_true[i_device][i_batch],
                                padded_value=dataset.padded_value)
                        str_pred = idx2char(labels_pred[i_batch])

                        # Remove consecutive spaces
                        str_pred = spaces.sub('_', str_pred)

                        # Remove garbage labels
                        str_true = garbage_labels.sub('', str_true)
                        str_pred = garbage_labels.sub('', str_pred)

                        # Compute WER
                        wer_device += compute_wer(ref=str_true.split('_'),
                                                  hyp=str_pred.split('_'),
                                                  normalize=True)

                        # Remove spaces
                        str_true = spaces.sub('', str_true)
                        str_pred = spaces.sub('', str_pred)

                        # Compute CER
                        cer_device += compute_cer(str_pred=str_pred,
                                                  str_true=str_true,
                                                  normalize=True)
                except IndexError:
                    print('skipped')
                    skip_data_num += batch_size_device
                    # TODO: Conduct decoding again with batch size 1
                else:
                    cer_mean += cer_device
                    wer_mean += wer_device

                # Advance exactly once per utterance, skipped or not
                if pbar is not None:
                    pbar.update(batch_size_device)

            if is_new_epoch:
                break
    finally:
        if pbar is not None:
            pbar.close()
        # Register original batch size
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original

    cer_mean /= (len(dataset) - skip_data_num)
    wer_mean /= (len(dataset) - skip_data_num)
    # TODO: Fix this

    return cer_mean, wer_mean
def check_loading(self, label_type, data_type='dev_clean',
                  shuffle=False, sort_utt=False, sort_stop_epoch=None,
                  frame_stacking=False, splice=1, num_gpu=1):
    """Load mini-batches with the given configuration and print the first
    utterance of each one.

    The configuration is printed first, then a `Dataset` is built and
    iterated until `dataset.epoch_detail` reaches 0.05.

    Raises:
        ValueError: if an input of the first checked mini-batch is shorter
            than its label
    """
    separator = '========================================'
    settings = [
        ' label_type: %s' % label_type,
        ' data_type: %s' % data_type,
        ' shuffle: %s' % str(shuffle),
        ' sort_utt: %s' % str(sort_utt),
        ' sort_stop_epoch: %s' % str(sort_stop_epoch),
        ' frame_stacking: %s' % str(frame_stacking),
        ' splice: %d' % splice,
        ' num_gpu: %d' % num_gpu,
    ]
    print(separator)
    for setting in settings:
        print(setting)
    print(separator)

    # Stacking and skipping use the same factor
    num_stack = num_skip = 3 if frame_stacking else 1

    dataset = Dataset(
        data_type=data_type, train_data_size='train_clean100',
        label_type=label_type,
        batch_size=64, max_epoch=1, splice=splice,
        num_stack=num_stack, num_skip=num_skip,
        shuffle=shuffle,
        sort_utt=sort_utt, sort_stop_epoch=sort_stop_epoch,
        progressbar=True, num_gpu=num_gpu)

    print('=> Loading mini-batch...')
    mapping_dir = '../../metrics/mapping_files/ctc/'
    if label_type == 'character':
        map_file_path = mapping_dir + 'character.txt'
    elif label_type == 'character_capital_divide':
        map_file_path = mapping_dir + 'character_capital.txt'
    elif label_type == 'word':
        map_file_path = mapping_dir + 'word_' + \
            dataset.train_data_size + '.txt'

    idx2char = Idx2char(map_file_path)
    idx2word = Idx2word(map_file_path)

    is_word = label_type == 'word'
    for data, _ in dataset:
        inputs, labels, inputs_seq_len, input_names = data

        # The length check runs only until it has succeeded once
        if not self.length_check:
            for x, y in zip(inputs[0], labels[0]):
                if len(x) < len(y):
                    raise ValueError(
                        'input length must be longer than label length.')
            self.length_check = True

        if num_gpu > 1:
            for inputs_gpu in inputs:
                print(inputs_gpu.shape)

        first_label = labels[0][0]
        if not is_word:
            str_true = re.sub(r'_', ' ', idx2char(first_label))
        elif 'test' not in data_type:
            str_true = ' '.join(idx2word(first_label))
        else:
            # `== None` is kept on purpose: it is evaluated by the array
            # type of the label, which `is None` would bypass
            none_positions = np.where(first_label == None)
            str_true = ' '.join(
                np.delete(first_label, none_positions, axis=0))

        print('----- %s (epoch: %.3f) -----' %
              (input_names[0][0], dataset.epoch_detail))
        print(inputs[0].shape)
        print(str_true)

        if dataset.epoch_detail >= 0.05:
            break
def decode(session, decode_op, model, dataset, label_type, ss_type,
           is_test=False, eval_batch_size=None, save_path=None):
    """Visualize label outputs of CTC model.

    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): kana
        ss_type (string): remove or insert_left or insert_both or
            insert_right
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the
            model
        save_path (string, optional): if not None, the decoding results are
            written to `decode.txt` under `model.model_dir` instead of the
            standard output
    """
    idx2char = Idx2char(map_file_path='../metrics/mapping_files/' +
                        label_type + '_' + ss_type + '.txt')

    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Redirect the standard output only for the duration of this call so
    # that the file is closed and later prints are not captured.
    stdout_original = sys.stdout
    out_file = None

    try:
        if save_path is not None:
            out_file = open(join(model.model_dir, 'decode.txt'), 'w')
            sys.stdout = out_file

        for data, is_new_epoch in dataset:
            # Create feed dictionary for next mini batch
            inputs, labels_true, inputs_seq_len, input_names = data
            feed_dict = {
                model.inputs_pl_list[0]: inputs[0],
                model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
                model.keep_prob_pl_list[0]: 1.0
            }

            batch_size = inputs[0].shape[0]
            labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
            try:
                labels_pred = sparsetensor2list(labels_pred_st,
                                                batch_size=batch_size)
            except IndexError:
                # no output: every hypothesis in this batch is empty
                labels_pred = None

            for i_batch in range(batch_size):
                print('----- wav: %s -----' % input_names[0][i_batch])

                # The conversion does not depend on the label type
                if is_test:
                    str_true = labels_true[0][i_batch][0]
                else:
                    str_true = idx2char(labels_true[0][i_batch])
                if labels_pred is None:
                    str_pred = ''
                else:
                    str_pred = idx2char(labels_pred[i_batch])

                print('Ref: %s' % str_true)
                print('Hyp: %s' % str_pred)

            if is_new_epoch:
                break
    finally:
        if out_file is not None:
            sys.stdout = stdout_original
            out_file.close()
        # Register original batch size
        if eval_batch_size is not None:
            dataset.batch_size = batch_size_original
def decode(model, dataset, beam_width, beam_width_sub,
           eval_batch_size=None, save_path=None, resolving_unk=False):
    """Visualize label outputs of the main and the sub task.

    Decodes one epoch of `dataset` and prints, per utterance, the reference,
    the main/sub task hypotheses and their WER.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width: (int): the size of beam in the main task
        beam_width_sub: (int): the size of beam in the sub task
        eval_batch_size (int, optional): the batch size when evaluating the
            model
        save_path (string): when not None, the output is written to
            `decode.txt` under `model.model_dir` instead of the console.
            NOTE: only tested against None; its value is not used as a path.
        resolving_unk (bool, optional): if True, hypotheses which include
            'OOV' are additionally passed to `resolve_unk` together with
            the sub task outputs
    """
    import re

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    idx2word = Idx2word(vocab_file_path=dataset.vocab_file_path)
    idx2char = Idx2char(
        vocab_file_path=dataset.vocab_file_path_sub,
        capital_divide=dataset.label_type_sub == 'character_capital_divide')

    # Read GLM file
    glm = GLM(
        glm_path='/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm')

    stdout_original = sys.stdout
    log_file = None
    try:
        if save_path is not None:
            log_file = open(join(model.model_dir, 'decode.txt'), 'w')
            sys.stdout = log_file

        for batch, is_new_epoch in dataset:
            # Decode
            if model.model_type == 'nested_attention':
                best_hyps, aw, best_hyps_sub, aw_sub, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    beam_width_sub=beam_width_sub,
                    max_decode_len=MAX_DECODE_LEN_WORD,
                    max_decode_len_sub=MAX_DECODE_LEN_CHAR)
            else:
                best_hyps, aw, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=MAX_DECODE_LEN_WORD)
                best_hyps_sub, aw_sub, _ = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width_sub,
                    max_decode_len=MAX_DECODE_LEN_CHAR,
                    task_index=1)

            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]
            ys_sub = batch['ys_sub'][perm_idx]
            y_lens_sub = batch['y_lens_sub'][perm_idx]

            for b in range(len(batch['xs'])):
                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref_original = ys[b][0]
                    str_ref_sub = ys_sub[b][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # Convert from list of index to string
                    str_ref_original = idx2word(ys[b][:y_lens[b]])
                    # Sub task labels are indices of the sub task vocabulary,
                    # so they are mapped by idx2char (not idx2word)
                    str_ref_sub = idx2char(ys_sub[b][:y_lens_sub[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = idx2word(best_hyps[b])
                str_hyp_sub = idx2char(best_hyps_sub[b])

                ##############################
                # Resolving UNK
                ##############################
                # Stays None unless UNK resolution is actually performed, so
                # that no undefined/stale value is used below
                str_hyp_no_unk = None
                if resolving_unk and 'OOV' in str_hyp:
                    str_hyp_no_unk = resolve_unk(
                        str_hyp, best_hyps_sub[b], aw[b], aw_sub[b], idx2char)

                ##############################
                # Post-processing
                ##############################
                str_ref = fix_trans(str_ref_original, glm)
                str_ref_sub = fix_trans(str_ref_sub, glm)
                str_hyp = fix_trans(str_hyp, glm)
                str_hyp_sub = fix_trans(str_hyp_sub, glm)
                if str_hyp_no_unk is not None:
                    str_hyp_no_unk = fix_trans(str_hyp_no_unk, glm)

                if len(str_ref) == 0:
                    continue

                show_no_unk = str_hyp_no_unk is not None and 'OOV' in str_hyp

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref         : %s' % str_ref.replace('_', ' '))
                print('Hyp (main)  : %s' % str_hyp.replace('_', ' '))
                print('Hyp (sub)   : %s' % str_hyp_sub.replace('_', ' '))
                if show_no_unk:
                    print('Hyp (no UNK): %s' %
                          str_hyp_no_unk.replace('_', ' '))

                try:
                    # Compute WER
                    # NOTE: everything from '>' on is dropped with a regex
                    # (str.replace would treat the pattern as a literal)
                    wer, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'_>.*', '', str_hyp).split('_'),
                        normalize=True)
                    print('WER (main)  : %.3f %%' % (wer * 100))
                    wer_sub, _, _, _ = compute_wer(
                        ref=str_ref_sub.split('_'),
                        hyp=re.sub(r'>.*', '', str_hyp_sub).split('_'),
                        normalize=True)
                    print('WER (sub)   : %.3f %%' % (wer_sub * 100))
                    if show_no_unk:
                        wer_no_unk, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=re.sub(
                                r'_>.*', '',
                                str_hyp_no_unk.replace('*', '')).split('_'),
                            normalize=True)
                        print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))
                except Exception:
                    # Best effort: an utterance whose WER cannot be computed
                    # is reported and skipped
                    print('--- skipped ---')

            if is_new_epoch:
                break
    finally:
        # Give the console back and close the log file
        if log_file is not None:
            sys.stdout = stdout_original
            log_file.close()
def do_eval_cer(models, model_type, dataset, label_type, beam_width,
                max_decode_len, eval_batch_size=None, temperature=1,
                progressbar=False):
    """Evaluate trained models by Character Error Rate.

    The posteriors of all models are averaged (ensemble) and decoded by the
    last model in `models`.

    Args:
        models (list): the models to evaluate (ensemble)
        model_type (string): ctc or attention or hierarchical_ctc or
            hierarchical_attention
        dataset: An instance of a `Dataset' class
        label_type (string): character or character_capital_divide
            NOTE: not read here; `dataset.label_type` is used instead
        beam_width: (int): the size of beam
        max_decode_len (int): the length of output sequences
            to stop prediction when EOS token have not been emitted.
            This is used for seq2seq models.
        eval_batch_size (int, optional): the batch size when evaluating the
            model
        temperature (int, optional): passed to `model.posteriors`
        progressbar (bool, optional): if True, visualize the progressbar
    Returns:
        cer (float): Character error rate
        wer (float): Word error rate
        df_wer_cer (pd.DataFrame): dataframe of substitution, insertion,
            and deletion
    Raises:
        ValueError: if `models` is empty or `model_type` is unknown
        NotImplementedError: if `model_type` is hierarchical
        ZeroDivisionError: if the references include no characters
    """
    if len(models) == 0:
        raise ValueError('models must include at least one model.')

    # Reset data counter
    dataset.reset()

    idx2char = Idx2char(
        vocab_file_path=dataset.vocab_file_path,
        capital_divide=(dataset.label_type == 'character_capital_divide'))

    # The last model of the ensemble decodes the averaged posteriors
    last_model = models[-1]
    is_attention = 'attention' in last_model.model_type

    cer, wer = 0, 0
    sub_char, ins_char, del_char = 0, 0, 0
    sub_word, ins_word, del_word = 0, 0, 0
    num_words, num_chars = 0, 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # Decode the ensemble
        if model_type in ['attention', 'ctc']:
            probs = None
            for model in models:
                probs_i, perm_idx = model.posteriors(
                    batch['xs'], batch['x_lens'], temperature=temperature)
                # Accumulate out of place so that the output of the first
                # model is not modified
                probs = probs_i if probs is None else probs + probs_i
                # NOTE: probs: `[1 (B), T, num_classes]`
            probs = probs / len(models)

            best_hyps = last_model.decode_from_probs(
                probs, batch['x_lens'][perm_idx],
                beam_width=beam_width,
                max_decode_len=max_decode_len)
            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]

        elif model_type in ['hierarchical_attention', 'hierarchical_ctc']:
            raise NotImplementedError
        else:
            raise ValueError('Unknown model_type: %s' % model_type)

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = idx2char(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            str_hyp = idx2char(best_hyps[b])
            if is_attention:
                str_hyp = str_hyp.split('>')[0]
                # NOTE: Truncate by the first <EOS>

                # Remove the last space
                if len(str_hyp) > 0 and str_hyp[-1] == '_':
                    str_hyp = str_hyp[:-1]

            # Remove consecutive spaces
            str_hyp = re.sub(r'[_]+', '_', str_hyp)

            ##############################
            # Post-processing
            ##############################
            # Remove garbage labels
            str_ref = re.sub(r'[\'>]+', '', str_ref)
            str_hyp = re.sub(r'[\'>]+', '', str_hyp)
            # TODO: is it OK to remove these when computing WER?

            # Compute WER
            wer_b, sub_b, ins_b, del_b = compute_wer(
                ref=str_ref.split('_'),
                hyp=str_hyp.split('_'),
                normalize=False)
            wer += wer_b
            sub_word += sub_b
            ins_word += ins_b
            del_word += del_b
            num_words += len(str_ref.split('_'))

            # Compute CER
            cer_b, sub_b, ins_b, del_b = compute_wer(
                ref=list(str_ref.replace('_', '')),
                hyp=list(str_hyp.replace('_', '')),
                normalize=False)
            cer += cer_b
            sub_char += sub_b
            ins_char += ins_b
            del_char += del_b
            num_chars += len(str_ref.replace('_', ''))

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Normalize the accumulated counts by the reference lengths
    wer /= num_words
    cer /= num_chars
    sub_char /= num_chars
    ins_char /= num_chars
    del_char /= num_chars
    sub_word /= num_words
    ins_word /= num_words
    del_word /= num_words

    df_wer_cer = pd.DataFrame(
        {
            'SUB': [sub_char * 100, sub_word * 100],
            'INS': [ins_char * 100, ins_word * 100],
            'DEL': [del_char * 100, del_word * 100]
        },
        columns=['SUB', 'INS', 'DEL'],
        index=['CER', 'WER'])

    return cer, wer, df_wer_cer
def decode(model, dataset, eval_batch_size, beam_width, length_penalty,
           save_path=None):
    """Visualize label outputs.

    Decodes one epoch of `dataset` and prints, per utterance, the reference,
    the hypothesis and its WER (word-level labels) or CER (otherwise).

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
            NOTE: currently not read by this function
        beam_width: (int): the size of beam
        length_penalty (float):
            NOTE: currently not read by this function
        save_path (string): when not None, the output is written to
            `decode.txt` under `model.model_dir` instead of the console.
            NOTE: only tested against None; its value is not used as a path.
    """
    if 'word' in dataset.label_type:
        map_fn = Idx2word(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_WORD
    else:
        map_fn = Idx2char(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_CHAR

    # CTC outputs are additionally shown for joint CTC-attention models
    use_ctc = model.model_type == 'attention' and model.ctc_loss_weight > 0

    stdout_original = sys.stdout
    log_file = None
    try:
        if save_path is not None:
            log_file = open(join(model.model_dir, 'decode.txt'), 'w')
            sys.stdout = log_file

        for batch, is_new_epoch in dataset:
            # Decode
            if model.model_type == 'nested_attention':
                best_hyps, _, best_hyps_sub, _, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_decode_len)
            else:
                best_hyps, _, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len)

            if use_ctc:
                best_hyps_ctc, perm_idx = model.decode_ctc(
                    batch['xs'], batch['x_lens'], beam_width=beam_width)

            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]

            for b in range(len(batch['xs'])):
                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref = ys[b][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # Convert from list of index to string
                    str_ref = map_fn(ys[b][:y_lens[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = map_fn(best_hyps[b])

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref: %s' % str_ref.replace('_', ' '))
                print('Hyp: %s' % str_hyp.replace('_', ' '))
                if use_ctc:
                    str_hyp_ctc = map_fn(best_hyps_ctc[b])
                    print('Hyp (CTC): %s' % str_hyp_ctc)

                try:
                    if dataset.label_type in ('word', 'kanji_wb'):
                        # Drop the first '>' (EOS), the spaces right before
                        # it and everything after it. A greedy '(.*)' group
                        # would cut at the last '>' and keep the trailing
                        # '_', which adds an empty word to the hypothesis.
                        wer, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=re.sub(r'_*>.*', '', str_hyp).split('_'),
                            normalize=True)
                        print('WER: %.3f %%' % (wer * 100))
                        if use_ctc:
                            wer_ctc, _, _, _ = compute_wer(
                                ref=str_ref.split('_'),
                                hyp=str_hyp_ctc.split('_'),
                                normalize=True)
                            print('WER (CTC): %.3f %%' % (wer_ctc * 100))
                    else:
                        # Truncate by the first '>' (EOS)
                        cer, _, _, _ = compute_wer(
                            ref=list(str_ref.replace('_', '')),
                            hyp=list(re.sub(r'>.*', '', str_hyp).replace(
                                '_', '')),
                            normalize=True)
                        print('CER: %.3f %%' % (cer * 100))
                        if use_ctc:
                            cer_ctc, _, _, _ = compute_wer(
                                ref=list(str_ref.replace('_', '')),
                                hyp=list(str_hyp_ctc.replace('_', '')),
                                normalize=True)
                            print('CER (CTC): %.3f %%' % (cer_ctc * 100))
                except Exception:
                    # Best effort: an utterance whose error rate cannot be
                    # computed is reported and skipped
                    print('--- skipped ---')

            if is_new_epoch:
                break
    finally:
        # Give the console back and close the log file
        if log_file is not None:
            sys.stdout = stdout_original
            log_file.close()
def decode(model, dataset, beam_width, eval_batch_size=None, save_path=None):
    """Visualize label outputs.

    Decodes one epoch of `dataset` and prints, per utterance, the reference
    and the hypothesis (both post-processed by `fix_trans`) and the WER.

    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        beam_width: (int): the size of beam
        eval_batch_size (int, optional): the batch size when evaluating the
            model
        save_path (string): when not None, the output is written to
            `decode.txt` under `model.model_dir` instead of the console.
            NOTE: only tested against None; its value is not used as a path.
    """
    import re

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    if 'char' in dataset.label_type:
        map_fn = Idx2char(
            dataset.vocab_file_path,
            capital_divide=dataset.label_type == 'character_capital_divide')
        max_decode_len = MAX_DECODE_LEN_CHAR
    else:
        map_fn = Idx2word(dataset.vocab_file_path)
        max_decode_len = MAX_DECODE_LEN_WORD

    # Read GLM file
    glm = GLM(
        glm_path='/n/sd8/inaguma/corpus/swbd/data/eval2000/LDC2002T43/reference/en20000405_hub5.glm')

    # CTC outputs are additionally shown for joint CTC-attention models
    use_ctc = model.model_type == 'attention' and model.ctc_loss_weight > 0

    stdout_original = sys.stdout
    log_file = None
    try:
        if save_path is not None:
            log_file = open(join(model.model_dir, 'decode.txt'), 'w')
            sys.stdout = log_file

        for batch, is_new_epoch in dataset:
            # Decode
            if model.model_type == 'nested_attention':
                best_hyps, _, best_hyps_sub, _, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len,
                    max_decode_len_sub=max_decode_len)
            else:
                best_hyps, _, perm_idx = model.decode(
                    batch['xs'], batch['x_lens'],
                    beam_width=beam_width,
                    max_decode_len=max_decode_len)

            if use_ctc:
                best_hyps_ctc, perm_idx = model.decode_ctc(
                    batch['xs'], batch['x_lens'], beam_width=beam_width)

            ys = batch['ys'][perm_idx]
            y_lens = batch['y_lens'][perm_idx]

            for b in range(len(batch['xs'])):
                ##############################
                # Reference
                ##############################
                if dataset.is_test:
                    str_ref_original = ys[b][0]
                    # NOTE: transcript is separated by space('_')
                else:
                    # Convert from list of index to string
                    str_ref_original = map_fn(ys[b][:y_lens[b]])

                ##############################
                # Hypothesis
                ##############################
                # Convert from list of index to string
                str_hyp = map_fn(best_hyps[b])

                ##############################
                # Post-processing
                ##############################
                str_ref = fix_trans(str_ref_original, glm)
                str_hyp = fix_trans(str_hyp, glm)

                if len(str_ref) == 0:
                    continue

                print('----- wav: %s -----' % batch['input_names'][b])
                print('Ref: %s' % str_ref.replace('_', ' '))
                print('Hyp: %s' % str_hyp.replace('_', ' '))
                if use_ctc:
                    str_hyp_ctc = map_fn(best_hyps_ctc[b])
                    print('Hyp (CTC): %s' % str_hyp_ctc)

                try:
                    # Compute WER
                    # NOTE: everything from the first '>' on (and one space
                    # right before it) is dropped with a regex; str.replace
                    # would treat the pattern as a literal substring
                    wer, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'_?>.*', '', str_hyp).split('_'),
                        normalize=True)
                    print('WER: %.3f %%' % (wer * 100))
                    # Same condition as the one under which str_hyp_ctc is
                    # defined above
                    if use_ctc:
                        wer_ctc, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=str_hyp_ctc.split('_'),
                            normalize=True)
                        print('WER (CTC): %.3f %%' % (wer_ctc * 100))
                except Exception:
                    # Best effort: an utterance whose WER cannot be computed
                    # is reported and skipped
                    print('--- skipped ---')

            if is_new_epoch:
                break
    finally:
        # Give the console back and close the log file
        if log_file is not None:
            sys.stdout = stdout_original
            log_file.close()