def check(self, label_type, data_type='dev', shuffle=False,
          sort_utt=False, sort_stop_epoch=None, frame_stacking=False,
          splice=1):
    print('========================================')
    print(' label_type: %s' % label_type)
    print(' data_type: %s' % data_type)
    print(' shuffle: %s' % str(shuffle))
    print(' sort_utt: %s' % str(sort_utt))
    print(' sort_stop_epoch: %s' % str(sort_stop_epoch))
    print(' frame_stacking: %s' % str(frame_stacking))
    print(' splice: %d' % splice)
    print('========================================')

    num_stack = 3 if frame_stacking else 1
    num_skip = 3 if frame_stacking else 1
    dataset = Dataset(data_type=data_type, label_type=label_type,
                      batch_size=64, max_epoch=2, splice=splice,
                      num_stack=num_stack, num_skip=num_skip,
                      shuffle=shuffle, sort_utt=sort_utt,
                      sort_stop_epoch=sort_stop_epoch,
                      progressbar=True)

    print('=> Loading mini-batch...')
    map_fn = Idx2phone(map_file_path='../../metrics/mapping_files/' +
                       label_type + '.txt')

    for data, is_new_epoch in dataset:
        inputs, labels, inputs_seq_len, input_names = data

        if data_type == 'train':
            for i_batch, l_batch in zip(inputs[0], labels[0]):
                if len(np.where(l_batch == dataset.padded_value)[0]) > 0:
                    if i_batch.shape[0] < np.where(
                            l_batch == dataset.padded_value)[0][0]:
                        raise ValueError(
                            'input length must be longer than label length.')
                else:
                    if i_batch.shape[0] < len(l_batch):
                        raise ValueError(
                            'input length must be longer than label length.')

        str_true = map_fn(labels[0][0])
        print('----- %s ----- (epoch: %.3f)' %
              (input_names[0][0], dataset.epoch_detail))
        print(inputs[0][0].shape)
        print(str_true)
def check_loading(self, label_type, data_type='dev', shuffle=False,
                  sort_utt=False, sort_stop_epoch=None,
                  frame_stacking=False, splice=1):
    print('========================================')
    print(' label_type: %s' % label_type)
    print(' data_type: %s' % data_type)
    print(' shuffle: %s' % str(shuffle))
    print(' sort_utt: %s' % str(sort_utt))
    print(' sort_stop_epoch: %s' % str(sort_stop_epoch))
    print(' frame_stacking: %s' % str(frame_stacking))
    print(' splice: %d' % splice)
    print('========================================')

    num_stack = 3 if frame_stacking else 1
    num_skip = 3 if frame_stacking else 1
    dataset = Dataset(
        data_type=data_type, label_type=label_type,
        batch_size=64, eos_index=1, max_epoch=2, splice=splice,
        num_stack=num_stack, num_skip=num_skip,
        shuffle=shuffle, sort_utt=sort_utt,
        sort_stop_epoch=sort_stop_epoch, progressbar=True)

    print('=> Loading mini-batch...')
    if label_type in ['character', 'character_capital_divide']:
        map_fn_ctc = Idx2char(
            map_file_path='../../metrics/mapping_files/ctc/' +
            label_type + '.txt')
        map_fn_att = Idx2char(
            map_file_path='../../metrics/mapping_files/attention/' +
            label_type + '.txt')
    else:
        map_fn_ctc = Idx2phone(
            map_file_path='../../metrics/mapping_files/ctc/' +
            label_type + '.txt')
        map_fn_att = Idx2phone(
            map_file_path='../../metrics/mapping_files/attention/' +
            label_type + '.txt')

    for data, is_new_epoch in dataset:
        inputs, att_labels, ctc_labels, inputs_seq_len, att_labels_seq_len, input_names = data

        att_str_true = map_fn_att(att_labels[0][0:att_labels_seq_len[0]])
        ctc_str_true = map_fn_ctc(ctc_labels[0])
        att_str_true = re.sub(r'_', ' ', att_str_true)
        ctc_str_true = re.sub(r'_', ' ', ctc_str_true)

        print('----- %s ----- (epoch: %.3f)' %
              (input_names[0], dataset.epoch_detail))
        print(att_str_true)
        print(ctc_str_true)
def plot_attention(model, dataset, eval_batch_size, beam_width,
                   length_penalty, save_path=None):
    """Visualize attention weights of the attention-based model.
    Args:
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam
        length_penalty (float):
        save_path (string, optional): path to save attention weights plotting
    """
    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    idx2phone = Idx2phone(dataset.vocab_file_path)

    for batch, is_new_epoch in dataset:
        # Decode
        best_hyps, aw, perm_idx = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=MAX_DECODE_LEN_PHONE,
            length_penalty=length_penalty)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space (' ')
            else:
                # Convert from list of index to string
                str_ref = idx2phone(ys[b][:y_lens[b]])

            # Split into a list of phone tokens so that the heatmap gets
            # one row (and one tick label) per output token
            token_list = idx2phone(best_hyps[b]).split(' ')

            plot_attention_weights(
                aw[b][:len(token_list), :batch['x_lens'][b]],
                label_list=token_list,
                spectrogram=batch['xs'][b, :, :40],
                str_ref=str_ref,
                save_path=join(save_path, batch['input_names'][b] + '.png'),
                figsize=(20, 8))

        if is_new_epoch:
            break
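# Usage sketch (hypothetical, not the repo's actual entry point): driving
# plot_attention from an evaluation script. `load_model` and the Dataset
# arguments are illustrative assumptions; only plot_attention's own
# signature is taken from the function above.
#
# model = load_model(checkpoint_path='./model.ckpt')  # hypothetical helper
# dataset = Dataset(data_type='test', label_type='phone61', batch_size=1)
# plot_attention(model, dataset, eval_batch_size=1, beam_width=4,
#                length_penalty=0, save_path='./attention_weights')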
def check(self, label_type_main, data_type='dev', shuffle=False,
          sort_utt=False, sort_stop_epoch=None, frame_stacking=False,
          splice=1):
    print('========================================')
    print(' label_type_main: %s' % label_type_main)
    print(' data_type: %s' % data_type)
    print(' shuffle: %s' % str(shuffle))
    print(' sort_utt: %s' % str(sort_utt))
    print(' sort_stop_epoch: %s' % str(sort_stop_epoch))
    print(' frame_stacking: %s' % str(frame_stacking))
    print(' splice: %d' % splice)
    print('========================================')

    num_stack = 3 if frame_stacking else 1
    num_skip = 3 if frame_stacking else 1
    dataset = Dataset(data_type=data_type,
                      label_type_main=label_type_main,
                      label_type_sub='phone61',
                      batch_size=64, max_epoch=2, splice=splice,
                      num_stack=num_stack, num_skip=num_skip,
                      shuffle=shuffle, sort_utt=sort_utt,
                      sort_stop_epoch=sort_stop_epoch,
                      progressbar=True)

    print('=> Loading mini-batch...')
    idx2char = Idx2char(map_file_path='../../metrics/mapping_files/' +
                        label_type_main + '.txt')
    idx2phone = Idx2phone(
        map_file_path='../../metrics/mapping_files/phone61.txt')

    for data, is_new_epoch in dataset:
        inputs, labels_char, labels_phone, inputs_seq_len, input_names = data

        if data_type != 'test':
            str_true_char = idx2char(labels_char[0][0])
            str_true_phone = idx2phone(labels_phone[0][0])
        else:
            str_true_char = labels_char[0][0][0]
            str_true_phone = labels_phone[0][0][0]

        print('----- %s ----- (epoch: %.3f)' %
              (input_names[0][0], dataset.epoch_detail))
        print(str_true_char)
        print(str_true_phone)
def check(self, label_type, data_type='dev', shuffle=False,
          sort_utt=False, sort_stop_epoch=None, frame_stacking=False,
          splice=1):
    print('========================================')
    print(' label_type: %s' % label_type)
    print(' data_type: %s' % data_type)
    print(' shuffle: %s' % str(shuffle))
    print(' sort_utt: %s' % str(sort_utt))
    print(' sort_stop_epoch: %s' % str(sort_stop_epoch))
    print(' frame_stacking: %s' % str(frame_stacking))
    print(' splice: %d' % splice)
    print('========================================')

    map_file_path = '../../metrics/mapping_files/' + label_type + '.txt'
    num_stack = 3 if frame_stacking else 1
    num_skip = 3 if frame_stacking else 1
    dataset = Dataset(data_type=data_type, label_type=label_type,
                      batch_size=64, map_file_path=map_file_path,
                      max_epoch=1, splice=splice,
                      num_stack=num_stack, num_skip=num_skip,
                      shuffle=shuffle, sort_utt=sort_utt,
                      sort_stop_epoch=sort_stop_epoch,
                      progressbar=True)

    print('=> Loading mini-batch...')
    map_fn = Idx2phone(map_file_path=map_file_path)

    for data, is_new_epoch in dataset:
        inputs, labels, inputs_seq_len, labels_seq_len, input_names = data

        str_true = map_fn(labels[0][0][0:labels_seq_len[0][0]])

        print('----- %s ----- (epoch: %.3f)' %
              (input_names[0][0], dataset.epoch_detail))
        print(inputs[0][0].shape)
        print(str_true)
def check(self, encoder_type, lstm_impl, time_major=False):
    print('==================================================')
    print(' encoder_type: %s' % str(encoder_type))
    print(' lstm_impl: %s' % str(lstm_impl))
    print(' time_major: %s' % str(time_major))
    print('==================================================')

    tf.reset_default_graph()
    with tf.Graph().as_default():
        # Load batch data
        batch_size = 2
        inputs, labels_char, labels_phone, inputs_seq_len = generate_data(
            label_type='multitask', model='ctc', batch_size=batch_size)

        # Define model graph
        num_classes_main = 27
        num_classes_sub = 61
        model = MultitaskCTC(
            encoder_type=encoder_type,
            input_size=inputs[0].shape[1],
            num_units=256,
            num_layers_main=2,
            num_layers_sub=1,
            num_classes_main=num_classes_main,
            num_classes_sub=num_classes_sub,
            main_task_weight=0.8,
            lstm_impl=lstm_impl,
            parameter_init=0.1,
            clip_grad_norm=5.0,
            clip_activation=50,
            num_proj=256,
            weight_decay=1e-8,
            # bottleneck_dim=50,
            bottleneck_dim=None,
            time_major=time_major)

        # Define placeholders
        model.create_placeholders()
        learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

        # Add to the graph each operation
        loss_op, logits_main, logits_sub = model.compute_loss(
            model.inputs_pl_list[0],
            model.labels_pl_list[0],
            model.labels_sub_pl_list[0],
            model.inputs_seq_len_pl_list[0],
            model.keep_prob_pl_list[0])
        train_op = model.train(loss_op,
                               optimizer='adam',
                               learning_rate=learning_rate_pl)
        decode_op_main, decode_op_sub = model.decoder(
            logits_main, logits_sub,
            model.inputs_seq_len_pl_list[0],
            beam_width=20)
        ler_op_main, ler_op_sub = model.compute_ler(
            decode_op_main, decode_op_sub,
            model.labels_pl_list[0], model.labels_sub_pl_list[0])

        # Define learning rate controller
        learning_rate = 1e-3
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=5,
                                   lower_better=True)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()),
               "{:,}".format(total_parameters / 1000000)))

        # Make feed dict
        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.labels_pl_list[0]: list2sparsetensor(
                labels_char, padded_value=-1),
            model.labels_sub_pl_list[0]: list2sparsetensor(
                labels_phone, padded_value=-1),
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_pl_list[0]: 0.9,
            learning_rate_pl: learning_rate
        }

        idx2phone = Idx2phone(map_file_path='./phone61.txt')

        with tf.Session() as sess:
            # Initialize parameters
            sess.run(init_op)

            # Wrapper for tfdbg
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

            # Train model
            max_steps = 1000
            start_time_step = time.time()
            for step in range(max_steps):
                # Compute loss
                _, loss_train = sess.run([train_op, loss_op],
                                         feed_dict=feed_dict)

                # Gradient check
                # grads = sess.run(model.clipped_grads,
                #                  feed_dict=feed_dict)
                # for grad in grads:
                #     print(np.max(grad))

                if (step + 1) % 10 == 0:
                    # Change to evaluation mode
                    feed_dict[model.keep_prob_pl_list[0]] = 1.0

                    # Compute accuracy
                    ler_train_char, ler_train_phone = sess.run(
                        [ler_op_main, ler_op_sub], feed_dict=feed_dict)

                    duration_step = time.time() - start_time_step
                    print('Step %d: loss = %.3f / cer = %.3f / per = %.3f (%.3f sec) / lr = %.5f' %
                          (step + 1, loss_train, ler_train_char,
                           ler_train_phone, duration_step, learning_rate))
                    start_time_step = time.time()

                    # Visualize
                    labels_pred_char_st, labels_pred_phone_st = sess.run(
                        [decode_op_main, decode_op_sub],
                        feed_dict=feed_dict)
                    labels_pred_char = sparsetensor2list(
                        labels_pred_char_st, batch_size=batch_size)
                    labels_pred_phone = sparsetensor2list(
                        labels_pred_phone_st, batch_size=batch_size)

                    print('Character')
                    try:
                        print(' Ref: %s' % idx2alpha(labels_char[0]))
                        print(' Hyp: %s' % idx2alpha(labels_pred_char[0]))
                    except IndexError:
                        print(' Ref: %s' % idx2alpha(labels_char[0]))
                        print(' Hyp: %s' % '')
                    print('Phone')
                    try:
                        print(' Ref: %s' % idx2phone(labels_phone[0]))
                        print(' Hyp: %s' % idx2phone(labels_pred_phone[0]))
                    except IndexError:
                        print(' Ref: %s' % idx2phone(labels_phone[0]))
                        print(' Hyp: %s' % '')
                        # NOTE: This is for no prediction
                    print('-' * 30)

                    if ler_train_char < 0.1:
                        print('Model converged.')
                        break

                    # Update learning rate
                    learning_rate = lr_controller.decay_lr(
                        learning_rate=learning_rate,
                        epoch=step,
                        value=ler_train_char)
                    feed_dict[learning_rate_pl] = learning_rate
def decode(session, decode_op_main, decode_op_sub, model, dataset,
           label_type_main, label_type_sub, is_test=True, save_path=None):
    """Visualize label outputs of the multi-task CTC model.
    Args:
        session: session of training model
        decode_op_main: operation for decoding in the main task
        decode_op_sub: operation for decoding in the sub task
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type_main (string): character or character_capital_divide
        label_type_sub (string): phone39 or phone48 or phone61
        is_test (bool, optional):
        save_path (string, optional): path to save decoding results
    """
    idx2char = Idx2char(map_file_path='../metrics/mapping_files/' +
                        label_type_main + '.txt')
    idx2phone = Idx2phone(map_file_path='../metrics/mapping_files/' +
                          label_type_sub + '.txt')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini-batch
        inputs, labels_true_char, labels_true_phone, inputs_seq_len, input_names = data
        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_pl_list[0]: 1.0
        }
        batch_size = inputs[0].shape[0]

        labels_pred_char_st, labels_pred_phone_st = session.run(
            [decode_op_main, decode_op_sub], feed_dict=feed_dict)
        try:
            labels_pred_char = sparsetensor2list(labels_pred_char_st,
                                                 batch_size=batch_size)
        except IndexError:
            # no output
            labels_pred_char = ['']
        try:
            labels_pred_phone = sparsetensor2list(labels_pred_phone_st,
                                                  batch_size=batch_size)
        except IndexError:
            # no output
            labels_pred_phone = ['']

        for i_batch in range(batch_size):
            print('----- wav: %s -----' % input_names[0][i_batch])
            if is_test:
                str_true_char = labels_true_char[0][i_batch][0].replace(
                    '_', ' ')
                str_true_phone = labels_true_phone[0][i_batch][0]
            else:
                str_true_char = idx2char(labels_true_char[0][i_batch])
                str_true_phone = idx2phone(labels_true_phone[0][i_batch])
            str_pred_char = idx2char(labels_pred_char[i_batch])
            str_pred_phone = idx2phone(labels_pred_phone[i_batch])

            print('Ref (char): %s' % str_true_char)
            print('Hyp (char): %s' % str_pred_char)
            print('Ref (phone): %s' % str_true_phone)
            print('Hyp (phone): %s' % str_pred_phone)

        if is_new_epoch:
            break
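# Usage sketch (assumption): calling the multi-task decode() above inside a
# restored session. `sess`, `saver`, `decode_op_main`, `decode_op_sub`, and
# `test_dataset` are assumed to come from the training/restore script; only
# decode()'s signature is taken from the function above.
#
# with tf.Session() as sess:
#     saver.restore(sess, checkpoint_path)
#     decode(sess, decode_op_main, decode_op_sub, model, test_dataset,
#            label_type_main='character', label_type_sub='phone61',
#            is_test=True, save_path=None)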
def check_training(self, encoder_type, label_type,
                   lstm_impl='LSTMBlockCell', save_params=False):
    print('==================================================')
    print(' encoder_type: %s' % encoder_type)
    print(' label_type: %s' % label_type)
    print(' lstm_impl: %s' % lstm_impl)
    print('==================================================')

    tf.reset_default_graph()
    with tf.Graph().as_default():
        # Load batch data
        batch_size = 1
        splice = 11 if encoder_type in [
            'vgg_blstm', 'vgg_lstm', 'vgg_wang',
            'resnet_wang', 'cnn_zhang'] else 1
        inputs, labels_true_st, inputs_seq_len = generate_data(
            label_type=label_type,
            model='ctc',
            batch_size=batch_size,
            splice=splice)
        # NOTE: input_size must be even number when using CudnnLSTM

        # Define model graph
        num_classes = 26 if label_type == 'character' else 61
        model = CTC(encoder_type=encoder_type,
                    input_size=inputs[0].shape[-1] // splice,
                    splice=splice,
                    num_units=256,
                    num_layers=2,
                    num_classes=num_classes,
                    lstm_impl=lstm_impl,
                    parameter_init=0.1,
                    clip_grad=5.0,
                    clip_activation=50,
                    num_proj=256,
                    # bottleneck_dim=50,
                    bottleneck_dim=None,
                    weight_decay=1e-8)

        # Define placeholders
        model.create_placeholders()
        learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

        # Add to the graph each operation
        loss_op, logits = model.compute_loss(
            model.inputs_pl_list[0],
            model.labels_pl_list[0],
            model.inputs_seq_len_pl_list[0],
            model.keep_prob_input_pl_list[0],
            model.keep_prob_hidden_pl_list[0],
            model.keep_prob_output_pl_list[0])
        train_op = model.train(loss_op,
                               optimizer='adam',
                               learning_rate=learning_rate_pl)
        # NOTE: Adam does not run on CudnnLSTM
        decode_op = model.decoder(logits,
                                  model.inputs_seq_len_pl_list[0],
                                  beam_width=20)
        ler_op = model.compute_ler(decode_op, model.labels_pl_list[0])

        # Define learning rate controller
        learning_rate = 1e-3
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   decay_start_epoch=10,
                                   decay_rate=0.98,
                                   decay_patient_epoch=5,
                                   lower_better=True)

        if save_params:
            # Create a saver for writing training checkpoints
            saver = tf.train.Saver(max_to_keep=None)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Count total parameters
        if lstm_impl != 'CudnnLSTM':
            parameters_dict, total_parameters = count_total_parameters(
                tf.trainable_variables())
            for parameter_name in sorted(parameters_dict.keys()):
                print("%s %d" %
                      (parameter_name, parameters_dict[parameter_name]))
            print("Total %d variables, %s M parameters" %
                  (len(parameters_dict.keys()),
                   "{:,}".format(total_parameters / 1000000)))

        # Make feed dict
        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.labels_pl_list[0]: labels_true_st,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0,
            learning_rate_pl: learning_rate
        }

        idx2phone = Idx2phone(map_file_path='./phone61_ctc.txt')

        with tf.Session() as sess:
            # Initialize parameters
            sess.run(init_op)

            # Wrapper for tfdbg
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

            # Train model
            max_steps = 1000
            start_time_global = time.time()
            start_time_step = time.time()
            ler_train_pre = 1
            not_improved_count = 0
            for step in range(max_steps):
                # Compute loss
                _, loss_train = sess.run([train_op, loss_op],
                                         feed_dict=feed_dict)

                # Gradient check
                # grads = sess.run(model.clipped_grads,
                #                  feed_dict=feed_dict)
                # for grad in grads:
                #     print(np.max(grad))

                if (step + 1) % 10 == 0:
                    # Change to evaluation mode
                    feed_dict[model.keep_prob_input_pl_list[0]] = 1.0
                    feed_dict[model.keep_prob_hidden_pl_list[0]] = 1.0
                    feed_dict[model.keep_prob_output_pl_list[0]] = 1.0

                    # Compute accuracy
                    ler_train = sess.run(ler_op, feed_dict=feed_dict)

                    duration_step = time.time() - start_time_step
                    print('Step %d: loss = %.3f / ler = %.3f (%.3f sec) / lr = %.5f' %
                          (step + 1, loss_train, ler_train,
                           duration_step, learning_rate))
                    start_time_step = time.time()

                    # Decode
                    labels_pred_st = sess.run(decode_op,
                                              feed_dict=feed_dict)
                    labels_true = sparsetensor2list(labels_true_st,
                                                    batch_size=batch_size)

                    # Visualize
                    try:
                        labels_pred = sparsetensor2list(
                            labels_pred_st, batch_size=batch_size)
                        if label_type == 'character':
                            print('Ref: %s' % idx2alpha(labels_true[0]))
                            print('Hyp: %s' % idx2alpha(labels_pred[0]))
                        else:
                            print('Ref: %s' % idx2phone(labels_true[0]))
                            print('Hyp: %s' % idx2phone(labels_pred[0]))
                    except IndexError:
                        if label_type == 'character':
                            print('Ref: %s' % idx2alpha(labels_true[0]))
                            print('Hyp: %s' % '')
                        else:
                            print('Ref: %s' % idx2phone(labels_true[0]))
                            print('Hyp: %s' % '')
                        # NOTE: This is for no prediction

                    if ler_train >= ler_train_pre:
                        not_improved_count += 1
                    else:
                        not_improved_count = 0
                    if ler_train < 0.05:
                        print('Model converged.')
                        if save_params:
                            # Save model (check point)
                            checkpoint_file = './model.ckpt'
                            save_path = saver.save(sess, checkpoint_file,
                                                   global_step=1)
                            print("Model saved in file: %s" % save_path)
                        break
                    ler_train_pre = ler_train

                    # Update learning rate
                    learning_rate = lr_controller.decay_lr(
                        learning_rate=learning_rate,
                        epoch=step,
                        value=ler_train)
                    feed_dict[learning_rate_pl] = learning_rate

            duration_global = time.time() - start_time_global
            print('Total time: %.3f sec' % (duration_global))
def decode(session, decode_op, model, dataset, label_type,
           is_test=False, save_path=None):
    """Visualize label outputs of the attention-based model.
    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61 or character or
            character_capital_divide
        is_test (bool, optional):
        save_path (string, optional): path to save decoding results
    """
    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True)
    else:
        map_fn = Idx2phone(map_file_path='../metrics/mapping_files/' +
                           label_type + '.txt')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini-batch
        inputs, labels_true, inputs_seq_len, labels_seq_len, input_names = data
        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_encoder_pl_list[0]: 1.0,
            model.keep_prob_decoder_pl_list[0]: 1.0,
            model.keep_prob_embedding_pl_list[0]: 1.0
        }
        batch_size = inputs[0].shape[0]

        labels_pred = session.run(decode_op, feed_dict=feed_dict)

        for i_batch in range(batch_size):
            print('----- wav: %s -----' % input_names[0][i_batch])

            if is_test:
                str_true = labels_true[0][i_batch][0]
            else:
                str_true = map_fn(
                    labels_true[0][i_batch][1:labels_seq_len[0][i_batch] - 1])
                # NOTE: Exclude <SOS> and <EOS>
            str_pred = map_fn(labels_pred[i_batch]).split('>')[0]
            # NOTE: Truncate by <EOS>

            if 'phone' in label_type:
                # Remove the last space
                if str_pred[-1] == ' ':
                    str_pred = str_pred[:-1]

            print('Ref: %s' % str_true)
            print('Hyp: %s' % str_pred)

        if is_new_epoch:
            break
def decode(session, decode_op, model, dataset, label_type,
           is_test=True, save_path=None):
    """Visualize label outputs of the CTC model.
    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61 or character or
            character_capital_divide
        is_test (bool, optional):
        save_path (string, optional): path to save decoding results
    """
    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True)
    else:
        map_fn = Idx2phone(map_file_path='../metrics/mapping_files/' +
                           label_type + '.txt')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini-batch
        inputs, labels_true, inputs_seq_len, input_names = data
        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_pl_list[0]: 1.0
        }
        batch_size = inputs[0].shape[0]

        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        try:
            labels_pred = sparsetensor2list(labels_pred_st,
                                            batch_size=batch_size)
        except IndexError:
            # no output
            labels_pred = ['']

        for i_batch in range(batch_size):
            print('----- wav: %s -----' % input_names[0][i_batch])
            # The same conversion applies to character and phone labels
            if is_test:
                str_true = labels_true[0][i_batch][0]
            else:
                str_true = map_fn(labels_true[0][i_batch])
            str_pred = map_fn(labels_pred[i_batch])

            print('Ref: %s' % str_true)
            print('Hyp: %s' % str_pred)

        if is_new_epoch:
            break
def do_eval_per(session, decode_op, per_op, model, dataset, label_type,
                is_test=False, eval_batch_size=None, progressbar=False,
                is_multitask=False, is_jointctcatt=False):
    """Evaluate trained model by Phone Error Rate.
    Args:
        session: session of training model
        decode_op: operation for decoding
        per_op: operation for computing phone error rate
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61
        is_test (bool, optional): set to True when evaluating by the test set
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
        is_jointctcatt (bool, optional): if True, evaluate the joint
            CTC-attention model
    Returns:
        per_mean (float): An average of PER
    """
    batch_size_original = dataset.batch_size

    # Reset data counter
    dataset.reset()

    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    train_label_type = label_type
    eval_label_type = dataset.label_type_sub if is_multitask else dataset.label_type

    idx2phone_train = Idx2phone(
        map_file_path='../metrics/mapping_files/' + train_label_type + '.txt')
    idx2phone_eval = Idx2phone(
        map_file_path='../metrics/mapping_files/' + eval_label_type + '.txt')
    map2phone39_train = Map2phone39(
        label_type=train_label_type,
        map_file_path='../metrics/mapping_files/phone2phone.txt')
    map2phone39_eval = Map2phone39(
        label_type=eval_label_type,
        map_file_path='../metrics/mapping_files/phone2phone.txt')

    per_mean = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini-batch
        if is_multitask:
            inputs, _, labels_true, inputs_seq_len, labels_seq_len, _ = data
        elif is_jointctcatt:
            inputs, labels_true, _, inputs_seq_len, labels_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, labels_seq_len, _ = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_encoder_pl_list[0]: 1.0,
            model.keep_prob_decoder_pl_list[0]: 1.0,
            model.keep_prob_embedding_pl_list[0]: 1.0
        }
        batch_size = inputs[0].shape[0]

        # Evaluate by 39 phones
        labels_pred = session.run(decode_op, feed_dict=feed_dict)
        for i_batch in range(batch_size):
            ###############
            # Hypothesis
            ###############
            # Convert from index to phone (-> list of phone strings)
            str_pred = idx2phone_train(labels_pred[i_batch]).split('>')[0]
            # NOTE: Truncate by <EOS>

            # Remove the last space
            if str_pred[-1] == ' ':
                str_pred = str_pred[:-1]

            phone_pred_list = str_pred.split(' ')

            ###############
            # Reference
            ###############
            if is_test:
                phone_true_list = labels_true[0][i_batch][0].split(' ')
            else:
                # Convert from index to phone (-> list of phone strings)
                phone_true_list = idx2phone_eval(
                    labels_true[0][i_batch][1:labels_seq_len[0][i_batch] - 1]
                ).split(' ')
                # NOTE: Exclude <SOS> and <EOS>

            # Mapping to 39 phones (-> list of phone strings)
            phone_pred_list = map2phone39_train(phone_pred_list)
            phone_true_list = map2phone39_eval(phone_true_list)

            # Compute PER
            per_mean += compute_per(ref=phone_true_list,
                                    hyp=phone_pred_list,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    per_mean /= len(dataset)

    # Register original batch size
    if eval_batch_size is not None:
        dataset.batch_size = batch_size_original

    return per_mean
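# Usage sketch (assumption): evaluating a trained attention model by PER with
# do_eval_per() above. `sess`, `decode_op`, `per_op`, and `dev_dataset` are
# assumed to exist from the training graph; the batch size is illustrative.
#
# per = do_eval_per(sess, decode_op, per_op, model, dev_dataset,
#                   label_type='phone61', is_test=False,
#                   eval_batch_size=1, progressbar=True)
# print('PER (dev): %.3f %%' % (per * 100))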
def plot(session, decode_op, attention_weights_op, model, dataset,
         label_type, is_test=False, save_path=None, show=False):
    """Visualize attention weights of the attention-based model.
    Args:
        session: session of training model
        decode_op: operation for decoding
        attention_weights_op: operation for computing attention weights
        model: model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string, optional): phone39 or phone48 or phone61 or
            character or character_capital_divide
        is_test (bool, optional):
        save_path (string, optional): path to save attention weights plotting
        show (bool, optional): if True, show each figure
    """
    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/character_capital_divide.txt',
            capital_divide=True)
    else:
        map_fn = Idx2phone(map_file_path='../metrics/mapping_files/' +
                           label_type + '.txt')

    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini-batch
        inputs, labels_true, inputs_seq_len, _, input_names = data
        feed_dict = {
            model.inputs_pl_list[0]: inputs[0],
            model.inputs_seq_len_pl_list[0]: inputs_seq_len[0],
            model.keep_prob_encoder_pl_list[0]: 1.0,
            model.keep_prob_decoder_pl_list[0]: 1.0,
            model.keep_prob_embedding_pl_list[0]: 1.0
        }

        # Visualize
        batch_size, max_frame_num = inputs.shape[:2]
        attention_weights, labels_pred = session.run(
            [attention_weights_op, decode_op], feed_dict=feed_dict)

        for i_batch in range(batch_size):
            # t_out, t_in = attention_weights[i_batch].shape

            # Check if the sum of attention weights equals to 1
            # print(np.sum(attention_weights[i_batch], axis=1))

            # Convert from index to label
            str_pred = map_fn(labels_pred[i_batch])
            if 'phone' in label_type:
                label_list = str_pred.split(' ')
            else:
                raise NotImplementedError

            plt.clf()
            plt.figure(figsize=(10, 4))
            sns.heatmap(attention_weights[i_batch],
                        cmap='Blues',
                        xticklabels=False,
                        yticklabels=label_list)
            plt.xlabel('Input frames', fontsize=12)
            plt.ylabel('Output labels (top to bottom)', fontsize=12)

            if show:
                plt.show()

            # Save as a png file
            if save_path is not None:
                plt.savefig(join(save_path, input_names[0] + '.png'),
                            dpi=500)

        if is_new_epoch:
            break
def decode_test_multitask(session, decode_op_main, decode_op_sub,
                          model, dataset, label_type_main, label_type_sub,
                          save_path=None):
    """Visualize label outputs of the multi-task CTC model.
    Args:
        session: session of training model
        decode_op_main: operation for decoding in the main task
        decode_op_sub: operation for decoding in the sub task
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type_main (string): character or character_capital_divide
        label_type_sub (string): phone39 or phone48 or phone61
        save_path (string, optional): path to save decoding results
    """
    # TODO: fix
    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    # Decode character
    print('===== ' + label_type_main + ' =====')
    idx2char = Idx2char(map_file_path='../metrics/mapping_files/ctc/' +
                        label_type_main + '.txt')
    while True:
        # Create feed dictionary for next mini-batch
        data, is_new_epoch = dataset.next(batch_size=1)
        inputs, labels_true, _, inputs_seq_len, input_names = data
        # NOTE: Batch size is expected to be 1
        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        # Visualize
        labels_pred_st = session.run(decode_op_main, feed_dict=feed_dict)
        labels_pred = sparsetensor2list(labels_pred_st, batch_size=1)

        print('----- wav: %s -----' % input_names[0])
        print('Ref: %s' % idx2char(labels_true[0]))
        print('Hyp: %s' % idx2char(labels_pred[0]))

        if is_new_epoch:
            break

    # Decode phone
    print('\n===== ' + label_type_sub + ' =====')
    idx2phone = Idx2phone(map_file_path='../metrics/mapping_files/ctc/' +
                          label_type_sub + '.txt')
    while True:
        # Create feed dictionary for next mini-batch
        data, is_new_epoch = dataset.next(batch_size=1)
        inputs, _, labels_true, inputs_seq_len, input_names = data
        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        # Visualize
        labels_pred_st = session.run(decode_op_sub, feed_dict=feed_dict)
        try:
            labels_pred = sparsetensor2list(labels_pred_st, batch_size=1)
        except IndexError:
            # no output
            labels_pred = ['']
        finally:
            print('----- wav: %s -----' % input_names[0])
            print('Ref: %s' % idx2phone(labels_true[0]))
            print('Hyp: %s' % idx2phone(labels_pred[0]))

        if is_new_epoch:
            break
def decode_test(session, decode_op, model, dataset, label_type,
                save_path=None):
    """Visualize label outputs of the CTC model.
    Args:
        session: session of training model
        decode_op: operation for decoding
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61 or character or
            character_capital_divide
        save_path (string, optional): path to save decoding results
    """
    if label_type == 'character':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/ctc/character.txt')
    elif label_type == 'character_capital_divide':
        map_fn = Idx2char(
            map_file_path='../metrics/mapping_files/ctc/character_capital_divide.txt',
            capital_divide=True)
    else:
        map_fn = Idx2phone(map_file_path='../metrics/mapping_files/ctc/' +
                           label_type + '.txt')

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    while True:
        # Create feed dictionary for next mini-batch
        data, is_new_epoch = dataset.next(batch_size=1)
        inputs, labels_true, inputs_seq_len, input_names = data
        # NOTE: Batch size is expected to be 1
        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }

        # Visualize
        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        try:
            labels_pred = sparsetensor2list(labels_pred_st, batch_size=1)
        except IndexError:
            # no output
            labels_pred = ['']
        finally:
            print('----- wav: %s -----' % input_names[0])
            if label_type == 'character':
                true_seq = map_fn(labels_true[0]).replace('_', ' ')
                pred_seq = map_fn(labels_pred[0]).replace('_', ' ')
            else:
                true_seq = map_fn(labels_true[0])
                pred_seq = map_fn(labels_pred[0])
            print('Ref: %s' % true_seq)
            print('Hyp: %s' % pred_seq)

        if is_new_epoch:
            break
def do_eval_per(session, decode_op, per_op, model, dataset, label_type,
                eval_batch_size=None, progressbar=False,
                is_multitask=False):
    """Evaluate trained model by Phone Error Rate.
    Args:
        session: session of training model
        decode_op: operation for decoding
        per_op: operation for computing phone error rate
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        label_type (string): phone39 or phone48 or phone61
        eval_batch_size (int, optional): the batch size when evaluating the model
        progressbar (bool, optional): if True, visualize the progressbar
        is_multitask (bool, optional): if True, evaluate the multitask model
    Returns:
        per_mean (float): An average of PER
    """
    # Reset data counter
    dataset.reset()

    train_label_type = label_type
    eval_label_type = dataset.label_type_sub if is_multitask else dataset.label_type

    # phone2idx_39_map_file_path = '../metrics/mapping_files/ctc/phone39.txt'
    idx2phone_train = Idx2phone(
        map_file_path='../metrics/mapping_files/ctc/' +
        train_label_type + '.txt')
    idx2phone_eval = Idx2phone(
        map_file_path='../metrics/mapping_files/ctc/' +
        eval_label_type + '.txt')
    map2phone39_train = Map2phone39(
        label_type=train_label_type,
        map_file_path='../metrics/mapping_files/phone2phone.txt')
    map2phone39_eval = Map2phone39(
        label_type=eval_label_type,
        map_file_path='../metrics/mapping_files/phone2phone.txt')

    per_mean = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))
    for data, is_new_epoch in dataset:
        # Create feed dictionary for next mini-batch
        if is_multitask:
            inputs, _, labels_true, inputs_seq_len, _ = data
        else:
            inputs, labels_true, inputs_seq_len, _ = data

        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.keep_prob_input_pl_list[0]: 1.0,
            model.keep_prob_hidden_pl_list[0]: 1.0,
            model.keep_prob_output_pl_list[0]: 1.0
        }
        batch_size_each = len(inputs)

        # Evaluate by 39 phones
        labels_pred_st = session.run(decode_op, feed_dict=feed_dict)
        labels_pred = sparsetensor2list(labels_pred_st, batch_size_each)
        for i_batch in range(batch_size_each):
            ###############
            # Hypothesis
            ###############
            # Convert from index to phone (-> list of phone strings)
            phone_pred_list = idx2phone_train(labels_pred[i_batch]).split(' ')

            # Mapping to 39 phones (-> list of phone strings)
            phone_pred_list = map2phone39_train(phone_pred_list)

            ###############
            # Reference
            ###############
            # Convert from index to phone (-> list of phone strings)
            phone_true_list = idx2phone_eval(labels_true[i_batch]).split(' ')

            # Mapping to 39 phones (-> list of phone strings)
            phone_true_list = map2phone39_eval(phone_true_list)

            # Compute PER
            per_mean += compute_per(ref=phone_true_list,
                                    hyp=phone_pred_list,
                                    normalize=True)

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    per_mean /= len(dataset)

    return per_mean
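# A minimal, self-contained sketch of the normalized edit distance that a
# compute_per-style metric boils down to (assumption: PER = token-level
# Levenshtein distance divided by reference length; this illustrates the
# metric, it is not the repo's compute_per implementation).
def _edit_distance(ref, hyp):
    # Dynamic-programming Levenshtein distance over phone tokens
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,         # deletion
                          d[i][j - 1] + 1,         # insertion
                          d[i - 1][j - 1] + cost)  # substitution
    return d[len(ref)][len(hyp)]

# Example: one substitution over four reference phones -> PER = 1 / 4 = 0.25
# print(_edit_distance(['sh', 'iy', 'hh', 'ae'],
#                      ['sh', 'iy', 'hh', 'aa']) / 4)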
def check(self, label_type, data_type='dev', backend='pytorch',
          shuffle=False, sort_utt=False, sort_stop_epoch=None,
          frame_stacking=False, splice=1):
    print('========================================')
    print(' backend: %s' % backend)
    print(' label_type: %s' % label_type)
    print(' data_type: %s' % data_type)
    print(' shuffle: %s' % str(shuffle))
    print(' sort_utt: %s' % str(sort_utt))
    print(' sort_stop_epoch: %s' % str(sort_stop_epoch))
    print(' frame_stacking: %s' % str(frame_stacking))
    print(' splice: %d' % splice)
    print('========================================')

    num_stack = 3 if frame_stacking else 1
    num_skip = 3 if frame_stacking else 1
    dataset = Dataset(data_save_path='/n/sd8/inaguma/corpus/timit/kaldi',
                      backend=backend,
                      input_freq=41, use_delta=True, use_double_delta=True,
                      data_type=data_type, label_type=label_type,
                      batch_size=64, max_epoch=2, splice=splice,
                      num_stack=num_stack, num_skip=num_skip,
                      shuffle=shuffle,
                      sort_utt=sort_utt, sort_stop_epoch=sort_stop_epoch,
                      tool='htk', num_enque=None)

    print('=> Loading mini-batch...')
    idx2phone = Idx2phone(dataset.vocab_file_path)

    for batch, is_new_epoch in dataset:
        if data_type == 'train' and backend == 'pytorch':
            for i in range(len(batch['xs'])):
                if batch['xs'].shape[1] < batch['ys'].shape[1]:
                    raise ValueError(
                        'input length must be longer than label length.')

        if dataset.is_test:
            str_true = batch['ys'][0][0]
        else:
            str_true = idx2phone(batch['ys'][0][:batch['y_lens'][0]])

        print('----- %s (epoch: %.3f, batch: %d) -----' %
              (batch['input_names'][0], dataset.epoch_detail,
               len(batch['xs'])))
        print(str_true)
        print('x_lens: %d' % (batch['x_lens'][0] * num_stack))
        if not dataset.is_test:
            print('y_lens: %d' % batch['y_lens'][0])
def __init__(self, data_save_path, backend, input_freq,
             use_delta, use_double_delta, data_type, label_type,
             batch_size, max_epoch=None, splice=1,
             num_stack=1, num_skip=1,
             shuffle=False, sort_utt=False, reverse=False,
             sort_stop_epoch=None, tool='htk',
             num_enque=None, dynamic_batching=False):
    """A class for loading dataset.
    Args:
        data_save_path (string): path to saved data
        backend (string): pytorch or chainer
        input_freq (int): the number of dimensions of acoustics
        use_delta (bool): if True, use the delta feature
        use_double_delta (bool): if True, use the acceleration feature
        data_type (string): train or dev or test
        label_type (string): phone39 or phone48 or phone61
        batch_size (int): the size of mini-batch
        max_epoch (int): the max epoch. None means infinite loop.
        splice (int): frames to splice. Default is 1 frame.
        num_stack (int): the number of frames to stack
        num_skip (int): the number of frames to skip
        shuffle (bool): if True, shuffle utterances. This is
            disabled when sort_utt is True.
        sort_utt (bool): if True, sort all utterances in the ascending order
        reverse (bool): if True, sort utterances in the descending order
        sort_stop_epoch (int): After sort_stop_epoch, training will revert
            back to a random order
        tool (string): htk or librosa or python_speech_features
        num_enque (int): the number of elements to enqueue
        dynamic_batching (bool): if True, batch size will be changed
            dynamically in training
    """
    self.backend = backend
    self.input_freq = input_freq
    self.use_delta = use_delta
    self.use_double_delta = use_double_delta
    self.data_type = data_type
    self.label_type = label_type
    self.batch_size = batch_size
    self.max_epoch = max_epoch
    self.splice = splice
    self.num_stack = num_stack
    self.num_skip = num_skip
    self.shuffle = shuffle
    self.sort_utt = sort_utt
    self.sort_stop_epoch = sort_stop_epoch
    self.num_gpus = 1
    self.tool = tool
    self.num_enque = num_enque
    self.dynamic_batching = dynamic_batching
    self.is_test = True if data_type == 'test' else False

    self.vocab_file_path = join(data_save_path, 'vocab',
                                label_type + '.txt')
    self.idx2phone = Idx2phone(self.vocab_file_path)

    super(Dataset, self).__init__(vocab_file_path=self.vocab_file_path)

    # Load dataset file
    dataset_path = join(data_save_path, 'dataset', tool, data_type,
                        label_type + '.csv')
    df = pd.read_csv(dataset_path)
    df = df.loc[:, ['frame_num', 'input_path', 'transcript']]

    # Sort paths to input & label
    if sort_utt:
        df = df.sort_values(by='frame_num', ascending=not reverse)
    else:
        df = df.sort_values(by='input_path', ascending=True)

    self.df = df
    self.rest = set(list(df.index))
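# Usage sketch (assumption): instantiating the Dataset class above. The
# data_save_path is a placeholder; the keyword arguments follow the
# constructor signature, and the batch keys ('xs', 'ys', 'x_lens', 'y_lens')
# match those used by the iteration test earlier in this section.
#
# dataset = Dataset(data_save_path='/path/to/timit/kaldi',  # placeholder
#                   backend='pytorch', input_freq=41,
#                   use_delta=True, use_double_delta=True,
#                   data_type='dev', label_type='phone61',
#                   batch_size=64, max_epoch=1, tool='htk')
# for batch, is_new_epoch in dataset:
#     xs, ys = batch['xs'], batch['ys']  # padded inputs and label indices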
def eval_phone(model, dataset, map_file_path, eval_batch_size,
               beam_width, max_decode_len,
               length_penalty=0, progressbar=False):
    """Evaluate trained model by Phone Error Rate.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        map_file_path (string): path to phones.60-48-39.map
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam
        max_decode_len (int): the length of output sequences to stop
            prediction. This is used for seq2seq models.
        length_penalty (float, optional):
        progressbar (bool, optional): if True, visualize the progressbar
    Returns:
        per (float): Phone error rate
        df_per (pd.DataFrame): dataframe of substitution, insertion, and deletion
    """
    # Reset data counter
    dataset.reset()

    idx2phone = Idx2phone(vocab_file_path=dataset.vocab_file_path)
    map2phone39 = Map2phone39(label_type=dataset.label_type,
                              map_file_path=map_file_path)

    per = 0
    sub, ins, dele = 0, 0, 0
    num_phones = 0
    if progressbar:
        pbar = tqdm(total=len(dataset))  # TODO: fix this
    while True:
        batch, is_new_epoch = dataset.next(batch_size=eval_batch_size)

        # Decode
        best_hyps, _, perm_idx = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=max_decode_len,
            length_penalty=length_penalty)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                phone_ref_list = ys[b][0].split(' ')
                # NOTE: transcript is separated by space (' ')
            else:
                # Convert from index to phone (-> list of phone strings)
                phone_ref_list = idx2phone(ys[b][:y_lens[b]]).split(' ')

            ##############################
            # Hypothesis
            ##############################
            # Convert from index to phone (-> list of phone strings)
            str_hyp = idx2phone(best_hyps[b])
            str_hyp = re.sub(r'(.*) >(.*)', r'\1', str_hyp)
            # NOTE: Truncate by the first <EOS>
            phone_hyp_list = str_hyp.split(' ')

            # Mapping to 39 phones (-> list of phone strings)
            if dataset.label_type != 'phone39':
                phone_ref_list = map2phone39(phone_ref_list)
                phone_hyp_list = map2phone39(phone_hyp_list)

            # Compute PER
            try:
                per_b, sub_b, ins_b, del_b = compute_wer(
                    ref=phone_ref_list,
                    hyp=phone_hyp_list,
                    normalize=False)
                per += per_b
                sub += sub_b
                ins += ins_b
                dele += del_b
                num_phones += len(phone_ref_list)
            except Exception:
                pass

            if progressbar:
                pbar.update(1)

        if is_new_epoch:
            break

    if progressbar:
        pbar.close()

    # Reset data counters
    dataset.reset()

    per /= num_phones
    sub /= num_phones
    ins /= num_phones
    dele /= num_phones

    df_per = pd.DataFrame({'SUB': [sub * 100],
                           'INS': [ins * 100],
                           'DEL': [dele * 100]},
                          columns=['SUB', 'INS', 'DEL'],
                          index=['PER'])

    return per, df_per
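# Usage sketch (assumption): computing the test-set PER with eval_phone()
# above. The map file name follows the docstring; the beam width and
# MAX_DECODE_LEN_PHONE are illustrative values from the surrounding scripts.
#
# per, df_per = eval_phone(model, test_dataset,
#                          map_file_path='./phones.60-48-39.map',
#                          eval_batch_size=1, beam_width=4,
#                          max_decode_len=MAX_DECODE_LEN_PHONE,
#                          length_penalty=0, progressbar=True)
# print('PER (test): %.3f %%' % (per * 100))
# print(df_per)  # substitution / insertion / deletion breakdown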
def check(self, encoder_type, attention_type, label_type='character'):
    print('==================================================')
    print(' encoder_type: %s' % encoder_type)
    print(' attention_type: %s' % attention_type)
    print(' label_type: %s' % label_type)
    print('==================================================')

    tf.reset_default_graph()
    with tf.Graph().as_default():
        # Load batch data
        batch_size = 4
        inputs, labels, inputs_seq_len, labels_seq_len = generate_data(
            label_type=label_type,
            model='attention',
            batch_size=batch_size)

        # Define model graph
        num_classes = 27 if label_type == 'character' else 61
        model = AttentionSeq2Seq(input_size=inputs[0].shape[1],
                                 encoder_type=encoder_type,
                                 encoder_num_units=256,
                                 encoder_num_layers=2,
                                 encoder_num_proj=None,
                                 attention_type=attention_type,
                                 attention_dim=128,
                                 decoder_type='lstm',
                                 decoder_num_units=256,
                                 decoder_num_layers=1,
                                 embedding_dim=64,
                                 num_classes=num_classes,
                                 sos_index=num_classes,
                                 eos_index=num_classes + 1,
                                 max_decode_length=100,
                                 use_peephole=True,
                                 splice=1,
                                 parameter_init=0.1,
                                 clip_grad_norm=5.0,
                                 clip_activation_encoder=50,
                                 clip_activation_decoder=50,
                                 weight_decay=1e-8,
                                 time_major=True,
                                 sharpening_factor=1.0,
                                 logits_temperature=1.0)

        # Define placeholders
        model.create_placeholders()
        learning_rate_pl = tf.placeholder(tf.float32, name='learning_rate')

        # Add to the graph each operation
        loss_op, logits, decoder_outputs_train, decoder_outputs_infer = model.compute_loss(
            model.inputs_pl_list[0],
            model.labels_pl_list[0],
            model.inputs_seq_len_pl_list[0],
            model.labels_seq_len_pl_list[0],
            model.keep_prob_encoder_pl_list[0],
            model.keep_prob_decoder_pl_list[0],
            model.keep_prob_embedding_pl_list[0])
        train_op = model.train(loss_op,
                               optimizer='adam',
                               learning_rate=learning_rate_pl)
        decode_op_train, decode_op_infer = model.decode(
            decoder_outputs_train, decoder_outputs_infer)
        ler_op = model.compute_ler(model.labels_st_true_pl,
                                   model.labels_st_pred_pl)

        # Define learning rate controller
        learning_rate = 1e-3
        lr_controller = Controller(learning_rate_init=learning_rate,
                                   decay_start_epoch=20,
                                   decay_rate=0.9,
                                   decay_patient_epoch=10,
                                   lower_better=True)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()),
               "{:,}".format(total_parameters / 1000000)))

        # Make feed dict
        feed_dict = {
            model.inputs_pl_list[0]: inputs,
            model.labels_pl_list[0]: labels,
            model.inputs_seq_len_pl_list[0]: inputs_seq_len,
            model.labels_seq_len_pl_list[0]: labels_seq_len,
            model.keep_prob_encoder_pl_list[0]: 0.8,
            model.keep_prob_decoder_pl_list[0]: 1.0,
            model.keep_prob_embedding_pl_list[0]: 1.0,
            learning_rate_pl: learning_rate
        }

        idx2phone = Idx2phone(map_file_path='./phone61.txt')

        with tf.Session() as sess:
            # Initialize parameters
            sess.run(init_op)

            # Wrapper for tfdbg
            # sess = tf_debug.LocalCLIDebugWrapperSession(sess)

            # Train model
            max_steps = 1000
            start_time_step = time.time()
            for step in range(max_steps):
                # Compute loss
                _, loss_train = sess.run([train_op, loss_op],
                                         feed_dict=feed_dict)

                # Gradient check
                # grads = sess.run(model.clipped_grads,
                #                  feed_dict=feed_dict)
                # for grad in grads:
                #     print(np.max(grad))

                if (step + 1) % 10 == 0:
                    # Change to evaluation mode
                    feed_dict[model.keep_prob_encoder_pl_list[0]] = 1.0
                    feed_dict[model.keep_prob_decoder_pl_list[0]] = 1.0
                    feed_dict[model.keep_prob_embedding_pl_list[0]] = 1.0

                    # Predict class ids
                    predicted_ids_train, predicted_ids_infer = sess.run(
                        [decode_op_train, decode_op_infer],
                        feed_dict=feed_dict)

                    # Compute accuracy
                    try:
                        feed_dict_ler = {
                            model.labels_st_true_pl: list2sparsetensor(
                                labels, padded_value=model.eos_index),
                            model.labels_st_pred_pl: list2sparsetensor(
                                predicted_ids_infer,
                                padded_value=model.eos_index)
                        }
                        ler_train = sess.run(ler_op,
                                             feed_dict=feed_dict_ler)
                    except IndexError:
                        ler_train = 1

                    duration_step = time.time() - start_time_step
                    print('Step %d: loss = %.3f / ler = %.3f (%.3f sec) / lr = %.5f' %
                          (step + 1, loss_train, ler_train,
                           duration_step, learning_rate))
                    start_time_step = time.time()

                    # Visualize
                    if label_type == 'character':
                        print('True : %s' % idx2alpha(labels[0]))
                        print('Pred (Training) : <%s' %
                              idx2alpha(predicted_ids_train[0]))
                        print('Pred (Inference): <%s' %
                              idx2alpha(predicted_ids_infer[0]))
                    else:
                        print('True : %s' % idx2phone(labels[0]))
                        print('Pred (Training) : < %s' %
                              idx2phone(predicted_ids_train[0]))
                        print('Pred (Inference): < %s' %
                              idx2phone(predicted_ids_infer[0]))

                    if ler_train < 0.1:
                        print('Model converged.')
                        break

                    # Update learning rate
                    learning_rate = lr_controller.decay_lr(
                        learning_rate=learning_rate,
                        epoch=step,
                        value=ler_train)
                    feed_dict[learning_rate_pl] = learning_rate
def decode(model, dataset, eval_batch_size, beam_width,
           length_penalty, save_path=None):
    """Visualize label outputs.
    Args:
        model: the model to evaluate
        dataset: An instance of a `Dataset` class
        eval_batch_size (int): the batch size when evaluating the model
        beam_width (int): the size of beam
        length_penalty (float):
        save_path (string, optional): path to save decoding results
    """
    idx2phone = Idx2phone(dataset.vocab_file_path)

    if save_path is not None:
        sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    for batch, is_new_epoch in dataset:
        # Decode
        best_hyps, _, perm_idx = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=beam_width,
            max_decode_len=MAX_DECODE_LEN_PHONE,
            length_penalty=length_penalty)

        if model.model_type == 'attention' and model.ctc_loss_weight > 0:
            best_hyps_ctc, perm_idx = model.decode_ctc(
                batch['xs'], batch['x_lens'], beam_width=beam_width)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                # NOTE: transcript is separated by space (' ')
            else:
                # Convert from list of index to string
                str_ref = idx2phone(ys[b][:y_lens[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = idx2phone(best_hyps[b])

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref : %s' % str_ref)
            print('Hyp : %s' % str_hyp)
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                str_hyp_ctc = idx2phone(best_hyps_ctc[b])
                print('Hyp (CTC): %s' % str_hyp_ctc)

            # Compute PER
            per, _, _, _ = compute_wer(
                ref=str_ref.split(' '),
                hyp=re.sub(r'(.*) >(.*)', r'\1', str_hyp).split(' '),
                normalize=True)
            print('PER: %.3f %%' % (per * 100))
            if model.model_type == 'attention' and model.ctc_loss_weight > 0:
                per_ctc, _, _, _ = compute_wer(ref=str_ref.split(' '),
                                               hyp=str_hyp_ctc.split(' '),
                                               normalize=True)
                print('PER (CTC): %.3f %%' % (per_ctc * 100))

        if is_new_epoch:
            break
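# Usage sketch (assumption): dumping decoding results with decode() above.
# `model` and `dev_dataset` are assumed to be a trained model and a Dataset
# instance built elsewhere in the repo; the argument values are illustrative.
#
# decode(model, dev_dataset, eval_batch_size=1, beam_width=4,
#        length_penalty=0, save_path=None)  # save_path=None prints to stdout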