import time
import sys
from os.path import join

import tensorflow as tf

# NOTE: DataSet, list2sparsetensor, count_total_parameters, save_loss,
# do_eval_cer, do_eval_per, and decode_test are repo-local helpers assumed
# to be imported from the original scripts' utility modules.


def do_restore(network, label_type, num_stack, num_skip, epoch=None):
    """Restore model.
    Args:
        network: model to restore
        label_type: phone, character, or kanji
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
        epoch: epoch to restore
    """
    # Load dataset
    eval1_data = DataSet(data_type='eval1', label_type=label_type,
                         num_stack=num_stack, num_skip=num_skip,
                         is_sorted=False, is_progressbar=True)
    eval2_data = DataSet(data_type='eval2', label_type=label_type,
                         num_stack=num_stack, num_skip=num_skip,
                         is_sorted=False, is_progressbar=True)
    eval3_data = DataSet(data_type='eval3', label_type=label_type,
                         num_stack=num_stack, num_skip=num_skip,
                         is_sorted=False, is_progressbar=True)

    # Define model
    network.define()

    # Add each operation to the graph
    decode_op = network.decoder(decode_type='beam_search', beam_width=20)
    posteriors_op = network.posteriors(decode_op)
    per_op = network.ler(decode_op)

    # Create a saver for writing training checkpoints
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(network.model_dir)

        # If a checkpoint exists
        if ckpt:
            # Use the last saved model by default
            model_path = ckpt.model_checkpoint_path
            if epoch is not None:
                # Restore the checkpoint of the given epoch instead
                model_path = model_path.split('/')[:-1]
                model_path = '/'.join(model_path) + '/model.ckpt-' + str(epoch)
            saver.restore(sess, model_path)
            print("Model restored: " + model_path)
        else:
            raise ValueError('There are no checkpoints.')

        if label_type in ['character', 'kanji']:
            print('■eval1 Evaluation:■')
            do_eval_cer(session=sess, decode_op=decode_op, network=network,
                        dataset=eval1_data,
                        eval_batch_size=network.batch_size,
                        is_progressbar=True)
            print('■eval2 Evaluation:■')
            do_eval_cer(session=sess, decode_op=decode_op, network=network,
                        dataset=eval2_data,
                        eval_batch_size=network.batch_size,
                        is_progressbar=True)
            print('■eval3 Evaluation:■')
            do_eval_cer(session=sess, decode_op=decode_op, network=network,
                        dataset=eval3_data,
                        eval_batch_size=network.batch_size,
                        is_progressbar=True)
        else:
            print('■eval1 Evaluation:■')
            do_eval_per(session=sess, per_op=per_op, network=network,
                        dataset=eval1_data,
                        eval_batch_size=network.batch_size,
                        is_progressbar=True)
            print('■eval2 Evaluation:■')
            do_eval_per(session=sess, per_op=per_op, network=network,
                        dataset=eval2_data,
                        eval_batch_size=network.batch_size,
                        is_progressbar=True)
            print('■eval3 Evaluation:■')
            do_eval_per(session=sess, per_op=per_op, network=network,
                        dataset=eval3_data,
                        eval_batch_size=network.batch_size,
                        is_progressbar=True)

        # Visualize
        decode_test(session=sess, decode_op=decode_op, network=network,
                    dataset=eval1_data, label_type=label_type)
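# A minimal sketch, shown in isolation, of the epoch-selection logic inside
# do_restore() above. Because do_train() calls saver.save(...,
# global_step=epoch), checkpoints are named 'model.ckpt-<epoch>', so
# restoring a specific epoch amounts to rewriting the latest-checkpoint path
# reported by tf.train.get_checkpoint_state(). `checkpoint_path_for_epoch`
# is a hypothetical helper, not part of the repo.
from os.path import dirname, join


def checkpoint_path_for_epoch(latest_path, epoch):
    """Point the latest-checkpoint path at a specific training epoch."""
    # e.g. 'models/ctc/model.ckpt-20' -> 'models/ctc/model.ckpt-5'
    return join(dirname(latest_path), 'model.ckpt-' + str(epoch))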
def do_train(network, optimizer, learning_rate, batch_size, epoch_num,
             label_type, num_stack, num_skip):
    """Run training.
    Args:
        network: network to train
        optimizer: string, the name of the optimizer, e.g. 'adam', 'rmsprop'
        learning_rate: initial learning rate
        batch_size: size of mini batch
        epoch_num: number of epochs to train
        label_type: phone39, phone48, phone61, or character
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
    """
    # Load dataset
    train_data = DataSet(data_type='train', label_type=label_type,
                         num_stack=num_stack, num_skip=num_skip,
                         is_sorted=True)
    if label_type == 'character':
        dev_data = DataSet(data_type='dev', label_type='character',
                           num_stack=num_stack, num_skip=num_skip,
                           is_sorted=False)
        test_data = DataSet(data_type='test', label_type='character',
                            num_stack=num_stack, num_skip=num_skip,
                            is_sorted=False)
    else:
        dev_data = DataSet(data_type='dev', label_type='phone39',
                           num_stack=num_stack, num_skip=num_skip,
                           is_sorted=False)
        test_data = DataSet(data_type='test', label_type='phone39',
                            num_stack=num_stack, num_skip=num_skip,
                            is_sorted=False)

    # Tell TensorFlow that the model will be built into the default graph
    with tf.Graph().as_default():

        # Define model
        network.define()
        # NOTE: define the model under tf.Graph()

        # Add each operation to the graph
        loss_op = network.loss()
        train_op = network.train(optimizer=optimizer,
                                 learning_rate_init=learning_rate,
                                 is_scheduled=False)
        decode_op = network.decoder(decode_type='beam_search', beam_width=20)
        per_op = network.ler(decode_op)

        # Build the summary tensor based on the TensorFlow collection of
        # summaries
        summary_train = tf.summary.merge(network.summaries_train)
        summary_dev = tf.summary.merge(network.summaries_dev)

        # Add the variable initializer operation
        init_op = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints
        saver = tf.train.Saver(max_to_keep=None)

        # Count total parameters
        parameters_dict, total_parameters = count_total_parameters(
            tf.trainable_variables())
        for parameter_name in sorted(parameters_dict.keys()):
            print("%s %d" % (parameter_name, parameters_dict[parameter_name]))
        print("Total %d variables, %s M parameters" %
              (len(parameters_dict.keys()),
               "{:,}".format(total_parameters / 1000000)))

        csv_steps = []
        csv_train_loss = []
        csv_dev_loss = []

        # Create a session for running operations on the graph
        with tf.Session() as sess:
            # Instantiate a SummaryWriter to output summaries and the graph
            summary_writer = tf.summary.FileWriter(network.model_dir,
                                                   sess.graph)

            # Initialize parameters
            sess.run(init_op)

            # Train model
            iter_per_epoch = int(train_data.data_num / batch_size)
            if (train_data.data_num / batch_size) != iter_per_epoch:
                # Round up when the data size is not a multiple of batch_size
                iter_per_epoch += 1
            max_steps = iter_per_epoch * epoch_num
            start_time_train = time.time()
            start_time_epoch = time.time()
            start_time_step = time.time()
            error_best = 1  # best (lowest) error so far; LER is in [0, 1]
            for step in range(max_steps):

                # Create feed dictionary for next mini batch (train)
                inputs, labels, seq_len, _ = train_data.next_batch(
                    batch_size=batch_size)
                indices, values, dense_shape = list2sparsetensor(labels)
                feed_dict_train = {
                    network.inputs_pl: inputs,
                    network.label_indices_pl: indices,
                    network.label_values_pl: values,
                    network.label_shape_pl: dense_shape,
                    network.seq_len_pl: seq_len,
                    network.keep_prob_input_pl: network.dropout_ratio_input,
                    network.keep_prob_hidden_pl: network.dropout_ratio_hidden,
                    network.lr_pl: learning_rate
                }

                # Create feed dictionary for next mini batch (dev)
                inputs, labels, seq_len, _ = dev_data.next_batch(
                    batch_size=batch_size)
                indices, values, dense_shape = list2sparsetensor(labels)
                feed_dict_dev = {
                    network.inputs_pl: inputs,
                    network.label_indices_pl: indices,
                    network.label_values_pl: values,
                    network.label_shape_pl: dense_shape,
                    network.seq_len_pl: seq_len,
                    network.keep_prob_input_pl: network.dropout_ratio_input,
                    network.keep_prob_hidden_pl: network.dropout_ratio_hidden
                }

                # Update parameters & compute loss
                _, loss_train = sess.run([train_op, loss_op],
                                         feed_dict=feed_dict_train)
                loss_dev = sess.run(loss_op, feed_dict=feed_dict_dev)
                csv_steps.append(step)
                csv_train_loss.append(loss_train)
                csv_dev_loss.append(loss_dev)

                if (step + 1) % 10 == 0:
                    # Change feed dicts for evaluation (disable dropout)
                    feed_dict_train[network.keep_prob_input_pl] = 1.0
                    feed_dict_train[network.keep_prob_hidden_pl] = 1.0
                    feed_dict_dev[network.keep_prob_input_pl] = 1.0
                    feed_dict_dev[network.keep_prob_hidden_pl] = 1.0

                    # Compute accuracy & update event file
                    ler_train, summary_str_train = sess.run(
                        [per_op, summary_train], feed_dict=feed_dict_train)
                    ler_dev, summary_str_dev, labels_st = sess.run(
                        [per_op, summary_dev, decode_op],
                        feed_dict=feed_dict_dev)
                    summary_writer.add_summary(summary_str_train, step + 1)
                    summary_writer.add_summary(summary_str_dev, step + 1)
                    summary_writer.flush()

                    duration_step = time.time() - start_time_step
                    print('Step %d: loss = %.3f (%.3f) / ler = %.4f (%.4f) (%.3f min)' %
                          (step + 1, loss_train, loss_dev, ler_train, ler_dev,
                           duration_step / 60))
                    sys.stdout.flush()
                    start_time_step = time.time()

                # Save checkpoint and evaluate model per epoch
                if (step + 1) % iter_per_epoch == 0 or (step + 1) == max_steps:
                    duration_epoch = time.time() - start_time_epoch
                    epoch = (step + 1) // iter_per_epoch
                    print('-----EPOCH:%d (%.3f min)-----' %
                          (epoch, duration_epoch / 60))

                    # Save model (checkpoint)
                    checkpoint_file = join(network.model_dir, 'model.ckpt')
                    save_path = saver.save(sess, checkpoint_file,
                                           global_step=epoch)
                    print("Model saved in file: %s" % save_path)

                    if epoch >= 10:
                        start_time_eval = time.time()
                        if label_type == 'character':
                            print('■Dev Data Evaluation:■')
                            error_epoch = do_eval_cer(session=sess,
                                                      decode_op=decode_op,
                                                      network=network,
                                                      dataset=dev_data,
                                                      eval_batch_size=1)
                            if error_epoch < error_best:
                                error_best = error_epoch
                                print('■■■ ↑Best Score (CER)↑ ■■■')

                                print('■Test Data Evaluation:■')
                                do_eval_cer(session=sess, decode_op=decode_op,
                                            network=network,
                                            dataset=test_data,
                                            eval_batch_size=1)
                        else:
                            print('■Dev Data Evaluation:■')
                            error_epoch = do_eval_per(session=sess,
                                                      decode_op=decode_op,
                                                      per_op=per_op,
                                                      network=network,
                                                      dataset=dev_data,
                                                      label_type=label_type,
                                                      eval_batch_size=1)
                            if error_epoch < error_best:
                                error_best = error_epoch
                                print('■■■ ↑Best Score (PER)↑ ■■■')

                                print('■Test Data Evaluation:■')
                                do_eval_per(session=sess, decode_op=decode_op,
                                            per_op=per_op, network=network,
                                            dataset=test_data,
                                            label_type=label_type,
                                            eval_batch_size=1)

                        duration_eval = time.time() - start_time_eval
                        print('Evaluation time: %.3f min' %
                              (duration_eval / 60))

                    start_time_epoch = time.time()
                    start_time_step = time.time()

            duration_train = time.time() - start_time_train
            print('Total time: %.3f hour' % (duration_train / 3600))

            # Save train & dev loss
            save_loss(csv_steps, csv_train_loss, csv_dev_loss,
                      save_path=network.model_dir)

            # Training finished correctly
            with open(join(network.model_dir, 'complete.txt'), 'w') as f:
                f.write('')
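# A minimal sketch of the contract assumed for list2sparsetensor() in the
# feed dictionaries above: it turns a batch of variable-length label
# sequences into the (indices, values, dense_shape) triple that feeds the
# tf.SparseTensor placeholders consumed by tf.nn.ctc_loss. This illustrates
# the expected behavior, not the repo's actual implementation.
import numpy as np


def list2sparsetensor_sketch(labels):
    """Convert a batch of label-id lists to SparseTensor components.
    Args:
        labels: list of lists of int label ids, one inner list per utterance
    Returns:
        indices: [[batch_index, time_index], ...]
        values: label ids flattened in the same order as indices
        dense_shape: [batch_size, max_label_length]
    """
    indices, values = [], []
    for batch_index, seq in enumerate(labels):
        for time_index, label in enumerate(seq):
            indices.append([batch_index, time_index])
            values.append(label)
    dense_shape = [len(labels), max(len(seq) for seq in labels)]
    return (np.array(indices, dtype=np.int64),
            np.array(values, dtype=np.int32),
            np.array(dense_shape, dtype=np.int64))

# Example: two utterances of lengths 3 and 2
# list2sparsetensor_sketch([[3, 1, 4], [2, 7]])
# -> indices [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],
#    values [3, 1, 4, 2, 7], dense_shape [2, 3]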
def do_restore(network, label_type, num_stack, num_skip, epoch=None):
    """Restore model.
    Args:
        network: model to restore
        label_type: phone39, phone48, phone61, or character
        num_stack: int, the number of frames to stack
        num_skip: int, the number of frames to skip
        epoch: epoch to restore
    """
    # Load dataset
    if label_type == 'character':
        test_data = DataSet(data_type='test', label_type='character',
                            num_stack=num_stack, num_skip=num_skip,
                            is_sorted=False, is_progressbar=True)
    else:
        test_data = DataSet(data_type='test', label_type='phone39',
                            num_stack=num_stack, num_skip=num_skip,
                            is_sorted=False, is_progressbar=True)

    # Define model
    network.define()

    # Add each operation to the graph
    decode_op1, decode_op2 = network.decoder(decode_type='beam_search',
                                             beam_width=20)
    # posteriors_op = network.posteriors(decode_op1)
    per_op1, per_op2 = network.ler(decode_op1, decode_op2)

    # Create a saver for writing training checkpoints
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(network.model_dir)

        # If a checkpoint exists
        if ckpt:
            # Use the last saved model by default
            model_path = ckpt.model_checkpoint_path
            if epoch is not None:
                # Restore the checkpoint of the given epoch instead
                model_path = model_path.split('/')[:-1]
                model_path = '/'.join(model_path) + '/model.ckpt-' + str(epoch)
            saver.restore(sess, model_path)
            print("Model restored: " + model_path)
        else:
            raise ValueError('There are no checkpoints.')

        print('Test Data Evaluation:')
        do_eval_cer(session=sess, decode_op=decode_op1, network=network,
                    dataset=test_data, is_progressbar=True,
                    is_multitask=True)
        do_eval_per(session=sess, decode_op=decode_op2, per_op=per_op2,
                    network=network, dataset=test_data,
                    label_type=label_type, is_progressbar=True,
                    is_multitask=True)
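# A minimal sketch of the label-error-rate op assumed behind network.ler()
# above: the mean normalized edit distance between the decoded hypothesis
# and the reference labels, both given as tf.SparseTensor. This is the
# standard TensorFlow pattern for CTC evaluation, not necessarily the
# repo's exact implementation.
import tensorflow as tf


def ler_op_sketch(decoded_st, labels_true_st):
    """Mean normalized edit distance over the batch.
    Args:
        decoded_st: tf.SparseTensor of hypothesis label ids, e.g. the first
            element returned by tf.nn.ctc_beam_search_decoder
        labels_true_st: tf.SparseTensor of reference label ids
    """
    return tf.reduce_mean(tf.edit_distance(
        tf.cast(decoded_st, tf.int32), labels_true_st, normalize=True))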