import sys
import time
from os.path import join

import tensorflow as tf

# NOTE: the project-local helpers used below (DataSet, count_total_parameters,
# do_eval_cer, do_eval_per, save_loss, save_ler) are assumed to be importable
# from this repository's data-loading and evaluation modules.


def do_train(network, optimizer, learning_rate, batch_size, epoch_num,
             label_type, num_stack, num_skip):
    """Run training. If the target labels are phones, the model is evaluated
    by PER with 39 phones.
    Args:
        network: network to train
        optimizer: string, the name of the optimizer, e.g. adam, rmsprop
        learning_rate: float, the initial learning rate
        batch_size: int, the size of mini-batch
        epoch_num: int, the number of epochs to train
        label_type: string, phone39 or phone48 or phone61 or character
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
    """
    # Load dataset
    train_data = DataSet(data_type='train', label_type=label_type,
                         batch_size=batch_size, num_stack=num_stack,
                         num_skip=num_skip, is_sorted=True)
    dev_data = DataSet(data_type='dev', label_type=label_type,
                       batch_size=batch_size, num_stack=num_stack,
                       num_skip=num_skip, is_sorted=False)
    if label_type == 'character':
        test_data = DataSet(data_type='test', label_type='character',
                            batch_size=1, num_stack=num_stack,
                            num_skip=num_skip, is_sorted=False)
    else:
        test_data = DataSet(data_type='test', label_type='phone39',
                            batch_size=1, num_stack=num_stack,
                            num_skip=num_skip, is_sorted=False)

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():

        # Define placeholders
        network.inputs = tf.placeholder(
            tf.float32, shape=[None, None, network.input_size], name='input')
        # CTC labels are fed as the three components of a SparseTensor
        indices_pl = tf.placeholder(tf.int64, name='indices')
        values_pl = tf.placeholder(tf.int32, name='values')
        shape_pl = tf.placeholder(tf.int64, name='shape')
        network.labels = tf.SparseTensor(indices_pl, values_pl, shape_pl)
        network.inputs_seq_len = tf.placeholder(
            tf.int64, shape=[None], name='inputs_seq_len')
        network.keep_prob_input = tf.placeholder(
            tf.float32, name='keep_prob_input')
        network.keep_prob_hidden = tf.placeholder(
            tf.float32, name='keep_prob_hidden')

        # Add to the graph each operation (including model definition)
        loss_op, logits = network.compute_loss(network.inputs,
                                               network.labels,
                                               network.inputs_seq_len,
                                               network.keep_prob_input,
                                               network.keep_prob_hidden)
        train_op = network.train(loss_op,
                                 optimizer=optimizer,
                                 learning_rate_init=float(learning_rate),
                                 is_scheduled=False)
        decode_op = network.decoder(logits,
                                    network.inputs_seq_len,
                                    decode_type='beam_search',
                                    beam_width=20)
        ler_op = network.compute_ler(decode_op, network.labels)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(network.summaries_train)
        summary_dev = tf.summary.merge(network.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()),
               "{:,}".format(total_parameters / 1000000)))

        # Make mini-batch generators
        mini_batch_train = train_data.next_batch()
        mini_batch_dev = dev_data.next_batch()

        csv_steps, csv_loss_train, csv_loss_dev = [], [], []
        csv_ler_train, csv_ler_dev = [], []

        # Create a session for running operations on the graph
        with tf.Session() as sess:

            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(
                network.model_dir, sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Train model
            iter_per_epoch = int(train_data.data_num / batch_size)
            train_step = train_data.data_num / batch_size
            if train_step != int(train_step):
                # Add one iteration for the last, smaller mini-batch
                iter_per_epoch += 1
            max_steps = iter_per_epoch * epoch_num
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            error_best = 1
            for step in range(max_steps):

                # Create feed dictionary for next mini-batch (train)
                inputs, labels_st, inputs_seq_len, _ = mini_batch_train.__next__()
                feed_dict_train = {
                    network.inputs: inputs,
                    network.labels: labels_st,
                    network.inputs_seq_len: inputs_seq_len,
                    network.keep_prob_input: network.dropout_ratio_input,
                    network.keep_prob_hidden: network.dropout_ratio_hidden,
                    # network.lr is the learning-rate placeholder created by
                    # network.train()
                    network.lr: learning_rate
                }

                # Create feed dictionary for next mini-batch (dev)
                inputs, labels_st, inputs_seq_len, _ = mini_batch_dev.__next__()
                feed_dict_dev = {
                    network.inputs: inputs,
                    network.labels: labels_st,
                    network.inputs_seq_len: inputs_seq_len,
                    network.keep_prob_input: network.dropout_ratio_input,
                    network.keep_prob_hidden: network.dropout_ratio_hidden
                }

                # Update parameters
                sess.run(train_op, feed_dict=feed_dict_train)

                if (step + 1) % 10 == 0:

                    # Compute loss
                    loss_train = sess.run(loss_op, feed_dict=feed_dict_train)
                    loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                    csv_steps.append(step)
                    csv_loss_train.append(loss_train)
                    csv_loss_dev.append(loss_dev)

                    # Change to evaluation mode (turn off dropout)
                    feed_dict_train[network.keep_prob_input] = 1.0
                    feed_dict_train[network.keep_prob_hidden] = 1.0
                    feed_dict_dev[network.keep_prob_input] = 1.0
                    feed_dict_dev[network.keep_prob_hidden] = 1.0

                    # Compute accuracy & update event file
                    ler_train, summary_str_train = sess.run(
                        [ler_op, summary_train], feed_dict=feed_dict_train)
                    ler_dev, summary_str_dev = sess.run(
                        [ler_op, summary_dev], feed_dict=feed_dict_dev)
                    csv_ler_train.append(ler_train)
                    csv_ler_dev.append(ler_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print("Step %d: loss = %.3f (%.3f) / ler = %.4f (%.4f) (%.3f min)" %
                          (step + 1, loss_train, loss_dev, ler_train, ler_dev,
                           duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if (step + 1) % iter_per_epoch == 0 or (step + 1) == max_steps:

                    duration_epoch = time.time() - start_time_epoch
                    epoch = (step + 1) // iter_per_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (epoch, duration_epoch / 60))

                    # Save model (checkpoint)
                    checkpoint_file = join(network.model_dir, 'model.ckpt')
                    save_path = saver.save(
                        sess, checkpoint_file, global_step=epoch)
                    print("Model saved in file: %s" % save_path)

                    if epoch >= 10:
                        start_time_eval = time.time()
                        if label_type == 'character':
                            print('=== Dev Data Evaluation ===')
                            cer_dev_epoch = do_eval_cer(session=sess,
                                                        decode_op=decode_op,
                                                        network=network,
                                                        dataset=dev_data)
                            print(' CER: %f %%' % (cer_dev_epoch * 100))

                            if cer_dev_epoch < error_best:
                                error_best = cer_dev_epoch
                                print('■■■ ↑Best Score (CER)↑ ■■■')

                                print('=== Test Data Evaluation ===')
                                cer_test = do_eval_cer(session=sess,
                                                       decode_op=decode_op,
                                                       network=network,
                                                       dataset=test_data,
                                                       eval_batch_size=1)
                                print(' CER: %f %%' % (cer_test * 100))

                        else:
                            print('=== Dev Data Evaluation ===')
                            per_dev_epoch = do_eval_per(
                                session=sess,
                                decode_op=decode_op,
                                per_op=ler_op,
                                network=network,
                                dataset=dev_data,
                                train_label_type=label_type)
                            print(' PER: %f %%' % (per_dev_epoch * 100))

                            if per_dev_epoch < error_best:
                                error_best = per_dev_epoch
                                print('■■■ ↑Best Score (PER)↑ ■■■')

                                print('=== Test Data Evaluation ===')
                                per_test = do_eval_per(
                                    session=sess,
                                    decode_op=decode_op,
                                    per_op=ler_op,
                                    network=network,
                                    dataset=test_data,
                                    train_label_type=label_type,
                                    eval_batch_size=1)
                                print(' PER: %f %%' % (per_test * 100))

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                    start_time_epoch = time.time()
                    start_time_step = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Save train & dev loss, ler
            save_loss(csv_steps, csv_loss_train, csv_loss_dev,
                      save_path=network.model_dir)
            save_ler(csv_steps, csv_ler_train, csv_ler_dev,
                     save_path=network.model_dir)

            # Training finished correctly
            with open(join(network.model_dir, 'complete.txt'), 'w') as f:
                f.write('')
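
# Usage sketch (a minimal example, not part of the original script): one way
# to invoke do_train(). `CTCNetwork` and all hyperparameter values below are
# hypothetical stand-ins; in practice the network class, its constructor
# arguments, and the hyperparameters come from this repository's model
# definitions and a training config.
if __name__ == '__main__':
    # Hypothetical model; replace with the repository's real network class
    network = CTCNetwork(input_size=123, num_classes=62)
    do_train(network=network,
             optimizer='adam',       # also accepts e.g. 'rmsprop'
             learning_rate=1e-3,
             batch_size=32,
             epoch_num=20,
             label_type='phone61',   # or phone48 / phone39 / character
             num_stack=3,            # stack 3 frames ...
             num_skip=3)             # ... and skip 3 frames (no overlap)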
def do_eval(network, label_type_second, num_stack, num_skip, epoch=None):
    """Evaluate the multitask model.
    Args:
        network: model to restore
        label_type_second: string, phone39 or phone48 or phone61
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
        epoch: int, the epoch to restore
    """
    # Load dataset
    if label_type_second == 'character':
        test_data = DataSet(data_type='test', label_type_second='character',
                            batch_size=1, num_stack=num_stack,
                            num_skip=num_skip, is_sorted=False,
                            is_progressbar=True)
    else:
        test_data = DataSet(data_type='test', label_type_second='phone39',
                            batch_size=1, num_stack=num_stack,
                            num_skip=num_skip, is_sorted=False,
                            is_progressbar=True)

    # Define placeholders
    network.inputs = tf.placeholder(
        tf.float32, shape=[None, None, network.input_size], name='input')
    # Main-task labels, fed as the components of a SparseTensor
    indices_pl = tf.placeholder(tf.int64, name='indices')
    values_pl = tf.placeholder(tf.int32, name='values')
    shape_pl = tf.placeholder(tf.int64, name='shape')
    network.labels = tf.SparseTensor(indices_pl, values_pl, shape_pl)
    # Second-task labels
    indices_second_pl = tf.placeholder(tf.int64, name='indices_second')
    values_second_pl = tf.placeholder(tf.int32, name='values_second')
    shape_second_pl = tf.placeholder(tf.int64, name='shape_second')
    network.labels_second = tf.SparseTensor(indices_second_pl,
                                            values_second_pl,
                                            shape_second_pl)
    network.inputs_seq_len = tf.placeholder(
        tf.int64, shape=[None], name='inputs_seq_len')

    # Add to the graph each operation
    _, logits_main, logits_second = network.compute_loss(
        network.inputs, network.labels, network.labels_second,
        network.inputs_seq_len)
    decode_op_main, decode_op_second = network.decoder(
        logits_main, logits_second, network.inputs_seq_len,
        decode_type='beam_search', beam_width=20)
    per_op_main, per_op_second = network.compute_ler(
        decode_op_main, decode_op_second,
        network.labels, network.labels_second)

    # Create a saver for restoring training checkpoints
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(network.model_dir)

        # If a checkpoint exists
        if ckpt:
            # Use the last saved model by default
            model_path = ckpt.model_checkpoint_path
            if epoch is not None:
                # Restore the checkpoint of the specified epoch instead
                model_path = model_path.split('/')[:-1]
                model_path = '/'.join(model_path) + '/model.ckpt-' + str(epoch)
            saver.restore(sess, model_path)
            print("Model restored: " + model_path)
        else:
            raise ValueError('There is no checkpoint.')

        print('=== Test Data Evaluation ===')
        cer_test = do_eval_cer(session=sess,
                               decode_op=decode_op_main,
                               network=network,
                               dataset=test_data,
                               is_progressbar=True,
                               is_multitask=True)
        print(' CER: %f %%' % (cer_test * 100))

        per_test = do_eval_per(session=sess,
                               decode_op=decode_op_second,
                               per_op=per_op_second,
                               network=network,
                               dataset=test_data,
                               train_label_type=label_type_second,
                               is_progressbar=True,
                               is_multitask=True)
        print(' PER: %f %%' % (per_test * 100))
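
# Usage sketch (a minimal example, not part of the original script).
# Checkpoint restoring happens inside do_eval() via network.model_dir, so the
# caller only rebuilds the same multitask network that was trained.
# `MultitaskCTCNetwork` and its arguments are hypothetical stand-ins for this
# repository's real multitask model class.
if __name__ == '__main__':
    # Hypothetical multitask model: characters (main) + phones (second)
    network = MultitaskCTCNetwork(input_size=123,
                                  num_classes_main=28,
                                  num_classes_second=62)
    do_eval(network=network,
            label_type_second='phone61',
            num_stack=3, num_skip=3,
            epoch=None)  # None => restore the last saved checkpoint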
def do_eval(network, label_type, num_stack, num_skip, train_data_size,
            epoch=None):
    """Evaluate the model.
    Args:
        network: model to restore
        label_type: string, phone or character or kanji
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
        train_data_size: string, default or large
        epoch: int, the epoch to restore
    """
    # Load dataset
    eval1_data = DataSet(data_type='eval1', label_type=label_type,
                         batch_size=1, train_data_size=train_data_size,
                         num_stack=num_stack, num_skip=num_skip,
                         is_sorted=False, is_progressbar=True)
    eval2_data = DataSet(data_type='eval2', label_type=label_type,
                         batch_size=1, train_data_size=train_data_size,
                         num_stack=num_stack, num_skip=num_skip,
                         is_sorted=False, is_progressbar=True)
    eval3_data = DataSet(data_type='eval3', label_type=label_type,
                         batch_size=1, train_data_size=train_data_size,
                         num_stack=num_stack, num_skip=num_skip,
                         is_sorted=False, is_progressbar=True)

    # Define placeholders
    network.inputs = tf.placeholder(
        tf.float32, shape=[None, None, network.input_size], name='input')
    indices_pl = tf.placeholder(tf.int64, name='indices')
    values_pl = tf.placeholder(tf.int32, name='values')
    shape_pl = tf.placeholder(tf.int64, name='shape')
    network.labels = tf.SparseTensor(indices_pl, values_pl, shape_pl)
    network.inputs_seq_len = tf.placeholder(
        tf.int64, shape=[None], name='inputs_seq_len')
    network.keep_prob_input = tf.placeholder(
        tf.float32, name='keep_prob_input')
    network.keep_prob_hidden = tf.placeholder(
        tf.float32, name='keep_prob_hidden')

    # Add to the graph each operation (including model definition)
    _, logits = network.compute_loss(network.inputs,
                                     network.labels,
                                     network.inputs_seq_len,
                                     network.keep_prob_input,
                                     network.keep_prob_hidden)
    decode_op = network.decoder(logits,
                                network.inputs_seq_len,
                                decode_type='beam_search',
                                beam_width=20)
    per_op = network.compute_ler(decode_op, network.labels)

    # Create a saver for restoring training checkpoints
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(network.model_dir)

        # If a checkpoint exists
        if ckpt:
            # Use the last saved model by default
            model_path = ckpt.model_checkpoint_path
            if epoch is not None:
                # Restore the checkpoint of the specified epoch instead
                model_path = model_path.split('/')[:-1]
                model_path = '/'.join(model_path) + '/model.ckpt-' + str(epoch)
            saver.restore(sess, model_path)
            print("Model restored: " + model_path)
        else:
            raise ValueError('There is no checkpoint.')

        if label_type in ['character', 'kanji']:
            print('=== eval1 Evaluation ===')
            cer_eval1 = do_eval_cer(session=sess,
                                    decode_op=decode_op,
                                    network=network,
                                    dataset=eval1_data,
                                    label_type=label_type,
                                    is_test=True,
                                    eval_batch_size=1,
                                    is_progressbar=True)
            print(' CER: %f %%' % (cer_eval1 * 100))

            print('=== eval2 Evaluation ===')
            cer_eval2 = do_eval_cer(session=sess,
                                    decode_op=decode_op,
                                    network=network,
                                    dataset=eval2_data,
                                    label_type=label_type,
                                    is_test=True,
                                    eval_batch_size=1,
                                    is_progressbar=True)
            print(' CER: %f %%' % (cer_eval2 * 100))

            print('=== eval3 Evaluation ===')
            cer_eval3 = do_eval_cer(session=sess,
                                    decode_op=decode_op,
                                    network=network,
                                    dataset=eval3_data,
                                    label_type=label_type,
                                    is_test=True,
                                    eval_batch_size=1,
                                    is_progressbar=True)
            print(' CER: %f %%' % (cer_eval3 * 100))

        else:
            print('=== eval1 Evaluation ===')
            per_eval1 = do_eval_per(session=sess,
                                    per_op=per_op,
                                    network=network,
                                    dataset=eval1_data,
                                    eval_batch_size=1,
                                    is_progressbar=True)
            print(' PER: %f %%' % (per_eval1 * 100))

            print('=== eval2 Evaluation ===')
            per_eval2 = do_eval_per(session=sess,
                                    per_op=per_op,
                                    network=network,
                                    dataset=eval2_data,
                                    eval_batch_size=1,
                                    is_progressbar=True)
            print(' PER: %f %%' % (per_eval2 * 100))

            print('=== eval3 Evaluation ===')
            per_eval3 = do_eval_per(session=sess,
                                    per_op=per_op,
                                    network=network,
                                    dataset=eval3_data,
                                    eval_batch_size=1,
                                    is_progressbar=True)
            print(' PER: %f %%' % (per_eval3 * 100))
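
# Usage sketch (a minimal example, not part of the original script): running
# the three-set (eval1/eval2/eval3) evaluation above. `CTCNetwork` and its
# arguments are hypothetical stand-ins for the repository's real model class;
# label_type and train_data_size must match the values used at training time.
if __name__ == '__main__':
    network = CTCNetwork(input_size=123, num_classes=3260)  # hypothetical
    do_eval(network=network,
            label_type='kanji',          # or 'character' / 'phone'
            num_stack=3, num_skip=3,
            train_data_size='default',   # or 'large'
            epoch=None)                  # None => last saved checkpoint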