def do_train(sess, args):
    # Set CPU as the default device for the graph. Some operations will be moved to the GPU later.
    with tf.device('/cpu:0'):
        # Feature and label placeholders
        # images_ph = tf.placeholder(tf.float32, shape=(None,) + tuple(args.processed_size), name='input')
        images_ph = tf.placeholder(tf.float32, shape=None, name='input')
        labels_ph = tf.placeholder(tf.int32, shape=None, name='label')
        max_seq_len_ph = tf.placeholder(tf.int32, shape=None, name='max_seq_len')
        label_length_batch_ph = tf.placeholder(tf.int32, shape=None, name='label_length_batch')

        # A placeholder that indicates whether we are training or validating;
        # it controls dropout rates and batch-norm parameters.
        is_training_ph = tf.placeholder(tf.bool, name='is_training')

        # Epoch number (worth a closer look): stored as a saved, non-trainable variable
        epoch_number = tf.get_variable(
            'epoch_number', [],
            dtype=tf.int32,
            initializer=tf.constant_initializer(0),
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, SAVE_VARIABLES])
        global_step = tf.get_variable(
            'global_step', [],
            dtype=tf.int32,
            initializer=tf.constant_initializer(0),
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, SAVE_VARIABLES])

        # Weight decay policy
        wd = utils.get_policy(args.WD_policy, args.WD_details)

        # Learning rate decay policy (if needed)
        # lr = utils.get_policy(args.LR_policy, args.LR_details)
        # TODO: the LR policy above may be broken; use a fixed learning rate for now
        lr = 0.0001

        # Create an optimizer that performs gradient descent.
        optimizer = utils.get_optimizer(args.optimizer, lr)

        # Create a pipeline to read data from disk.
        # A placeholder for setting the input-pipeline batch size. It is used to make sure
        # each validation example is fed to the network exactly once.
        # Because validation runs on a single GPU, the validation batch size must not exceed 512.
        batch_size_tf = tf.placeholder_with_default(min(512, args.batch_size), shape=())

        # A data-loader pipeline to read training examples and their labels
        train_loader = Loader(args.train_info, args.delimiter, args.raw_size,
                              args.processed_size, True, args.chunked_batch_size,
                              args.num_prefetch, args.num_threads, args.path_prefix,
                              args.shuffle)

        # The loader returns MFCC feature batches, labels, feature shapes, and sequence-length information
        # images, labels, info = train_loader.load()
        mfcc_feat_batch, label_batch, feat_shape_batch, seq_len_batch, max_seq_len, label_length_batch = train_loader.load()

        # Build the computational graph using the provided configuration.
        dnn_model = model(images_ph, labels_ph, utils.loss, optimizer, wd,
                          args.architecture, args.depth, args.num_chars,
                          args.num_classes, is_training_ph, max_seq_len_ph,
                          label_length_batch_ph, args.transfer_mode,
                          num_gpus=args.num_gpus)

        # If validation data are provided, create an input pipeline to load them
        if args.run_validation:
            val_loader = Loader(args.val_info, args.delimiter, args.raw_size,
                                args.processed_size, False, batch_size_tf,
                                args.num_prefetch, args.num_threads, args.path_prefix)
            # TODO: uncomment
            # val_images, val_labels, val_info = val_loader.load()

        # Get the training operations to run from the deep learning model
        train_ops = dnn_model.train_ops()

        # Build an initialization operation to run below.
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        sess.run(init)

        if args.retrain_from is not None:
            dnn_model.load(sess, args.retrain_from)

        # Set the start epoch number
        start_epoch = sess.run(epoch_number + 1)

        # Start the queue runners.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Set up a summary writer
        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph)

        # The main training loop
        for epoch in range(start_epoch, start_epoch + args.num_epochs):
            # Update the stored epoch number
            sess.run(epoch_number.assign(epoch))
            print("Epoch %d started" % (epoch))

            # Training batches
            for step in range(args.num_batches):
                sess.run(global_step.assign(step + epoch * args.num_batches))

                # Train the network on a batch of data (and time it)
                start_time = time.time()

                # Load a batch from the input pipeline. The results are bound to new names
                # so the pipeline tensors above are not overwritten by numpy arrays after
                # the first iteration.
                # img, lbl = sess.run([images, labels], options=args.run_options, run_metadata=args.run_metadata)
                mfcc_feat_np, label_np, feat_shape_np, seq_len_np, max_seq_len_np, label_length_np = sess.run(
                    [mfcc_feat_batch, label_batch, feat_shape_batch,
                     seq_len_batch, max_seq_len, label_length_batch],
                    options=args.run_options,
                    run_metadata=args.run_metadata)

                # Train on the loaded batch of data
                _, loss_value, top1_accuracy, topn_accuracy = sess.run(
                    train_ops,
                    feed_dict={images_ph: mfcc_feat_np,
                               labels_ph: label_np,
                               max_seq_len_ph: max_seq_len_np,
                               label_length_batch_ph: label_length_np,
                               is_training_ph: True},
                    options=args.run_options,
                    run_metadata=args.run_metadata)
                duration = time.time() - start_time

                # Check for divergence
                assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

                # Log every ten batches and write TensorBoard summaries every hundred batches
                if step % 10 == 0:
                    num_examples_per_step = args.chunked_batch_size * args.num_gpus
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration / args.num_gpus

                    # Log
                    format_str = ('%s: epoch %d, step %d, loss = %.2f, Top-1 = %.2f Top-' +
                                  str(args.top_n) +
                                  ' = %.2f (%.1f examples/sec; %.3f sec/batch)')
                    print(format_str % (datetime.now(), epoch, step, loss_value,
                                        top1_accuracy, topn_accuracy,
                                        examples_per_sec, sec_per_batch))
                    sys.stdout.flush()

                if step % 100 == 0:
                    summary_str = sess.run(
                        tf.summary.merge_all(),
                        feed_dict={images_ph: mfcc_feat_np,
                                   labels_ph: label_np,
                                   max_seq_len_ph: max_seq_len_np,
                                   label_length_batch_ph: label_length_np,
                                   is_training_ph: True})
                    summary_writer.add_summary(summary_str, args.num_batches * epoch + step)
                    # TODO: there seems to be a bug here
                    # if args.log_debug_info:
                    #     summary_writer.add_run_metadata(run_metadata, 'epoch%d step%d' % (epoch, step))

            # Save a model checkpoint after each training epoch
            checkpoint_path = os.path.join(args.log_dir, args.snapshot_prefix)
            dnn_model.save(sess, checkpoint_path, global_step=epoch)
            print("Epoch %d ended. A checkpoint was saved at %s" % (epoch, args.log_dir))
            sys.stdout.flush()

            # If validation data are provided, evaluate accuracy on the validation set
            # at the end of each epoch
            if args.run_validation:
                print("Evaluating on validation set")
                """
                true_predictions_count = 0       # Counts the number of correct predictions
                true_topn_predictions_count = 0  # Counts the number of top-n correct predictions
                total_loss = 0.0                 # Measures cross-entropy loss
                all_count = 0                    # Counts the total number of examples

                # The validation loop
                for step in range(args.num_val_batches):
                    # Load a batch of data
                    val_img, val_lbl = sess.run(
                        [val_images, val_labels],
                        feed_dict={batch_size_tf: args.num_val_samples % min(512, args.batch_size)}
                        if step == args.num_val_batches - 1 else None,
                        options=args.run_options,
                        run_metadata=args.run_metadata)

                    # Validate the network on the loaded batch
                    val_loss, top1_predictions, topn_predictions = sess.run(
                        [train_ops[1], train_ops[2], train_ops[3]],
                        feed_dict={images_ph: val_img, labels_ph: val_lbl, is_training_ph: False},
                        options=args.run_options,
                        run_metadata=args.run_metadata)

                    all_count += val_lbl.shape[0]
                    true_predictions_count += int(round(val_lbl.shape[0] * top1_predictions))
                    true_topn_predictions_count += int(round(val_lbl.shape[0] * topn_predictions))
                    total_loss += val_loss * val_lbl.shape[0]

                    if step % 10 == 0:
                        print("Validation step %d of %d" % (step, args.num_val_batches))
                        sys.stdout.flush()

                print("Total number of validation examples %d, Loss %.2f, Top-1 Accuracy %.2f, Top-%d Accuracy %.2f" %
                      (all_count, total_loss / all_count, true_predictions_count / all_count,
                       args.top_n, true_topn_predictions_count / all_count))
                sys.stdout.flush()
                """

        coord.request_stop()
        coord.join(threads)
        sess.close()
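

# ---------------------------------------------------------------------------
# Illustrative sketch only: `utils.loss` is defined elsewhere and is not shown
# in this file. Because the training loader yields MFCC feature batches along
# with `max_seq_len` and per-example label lengths, it is presumably a
# CTC-style objective. The helper below is a hedged sketch of how such a loss
# is commonly assembled with tf.nn.ctc_loss; the parameter names (`logits`,
# `dense_labels`, `label_lengths`, `frame_lengths`) are assumptions for
# illustration and do not correspond to identifiers used by this project.
# ---------------------------------------------------------------------------
def _ctc_loss_sketch(logits, dense_labels, label_lengths, frame_lengths):
    # logits:        [max_time, batch_size, num_classes] float32 network outputs
    # dense_labels:  [batch_size, max_label_len] int32 padded label ids
    # label_lengths: [batch_size] int32 true length of each label sequence
    # frame_lengths: [batch_size] int32 number of valid frames per example
    # Convert the padded dense labels into the SparseTensor tf.nn.ctc_loss expects.
    sparse_labels = tf.keras.backend.ctc_label_dense_to_sparse(dense_labels, label_lengths)
    losses = tf.nn.ctc_loss(labels=sparse_labels,
                            inputs=logits,
                            sequence_length=frame_lengths,
                            time_major=True)
    return tf.reduce_mean(losses)
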
def do_evaluate(sess, args):
    with tf.device('/cpu:0'):
        # Feature and label placeholders
        images_ph = tf.placeholder(tf.float32,
                                   shape=(None,) + tuple(args.processed_size),
                                   name='input')
        labels_ph = tf.placeholder(tf.int32, shape=None, name='label')

        # A placeholder that indicates whether we are training or validating;
        # it controls dropout rates and batch-norm parameters.
        is_training_ph = tf.placeholder(tf.bool, name='is_training')

        # Build a deep learning model using the provided configuration
        dnn_model = model(images_ph, labels_ph, utils.loss, None, 0.0,
                          args.architecture, args.depth, args.num_chars,
                          args.num_classes, is_training_ph, args.transfer_mode)

        # Create an input pipeline to read data from disk.
        # A placeholder for setting the input-pipeline batch size. It is used to make sure
        # each validation example is fed to the network exactly once.
        batch_size_tf = tf.placeholder_with_default(args.batch_size, shape=())

        # A data-loader pipeline to read test data
        val_loader = Loader(args.val_info, args.delimiter, args.raw_size,
                            args.processed_size, False, batch_size_tf,
                            args.num_prefetch, args.num_threads, args.path_prefix,
                            inference_only=args.inference_only)

        # For inference only (i.e. no labels are provided), the loader returns only
        # the examples and their paths
        if not args.inference_only:
            val_images, val_labels, val_info = val_loader.load()
        else:
            val_images, val_info = val_loader.load()

        # Get evaluation operations from the dnn model
        eval_ops = dnn_model.evaluate_ops(args.inference_only)

        # Build an initialization operation to run below.
        init = tf.group(tf.global_variables_initializer(),
                        tf.local_variables_initializer())
        sess.run(init)

        # Load pretrained parameters from disk
        dnn_model.load(sess, args.log_dir)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        # Evaluation
        if not args.inference_only:
            true_predictions_count = 0       # Counts the number of correct predictions
            true_topn_predictions_count = 0  # Counts the number of correct top-n predictions
            total_loss = 0.0                 # Measures cross-entropy loss
            all_count = 0                    # Counts the total number of examples

            # Open an output file to write predictions
            out_file = open(args.save_predictions, 'w')
            predictions_format_str = ('%d, %s, %d, %s, %s\n')

            for step in range(args.num_val_batches):
                # Load a batch of data. On the last step the batch size is shrunk so
                # each validation example is fed exactly once (this assumes
                # num_val_samples is not an exact multiple of the batch size).
                val_img, val_lbl, val_inf = sess.run(
                    [val_images, val_labels, val_info],
                    feed_dict={batch_size_tf: args.num_val_samples % args.batch_size}
                    if step == args.num_val_batches - 1 else None)

                # Evaluate the network on the loaded batch
                val_loss, top1_predictions, topn_predictions, topnguesses, topnconf = sess.run(
                    eval_ops,
                    feed_dict={images_ph: val_img,
                               labels_ph: val_lbl,
                               is_training_ph: False},
                    options=args.run_options,
                    run_metadata=args.run_metadata)

                true_predictions_count += np.sum(top1_predictions)
                true_topn_predictions_count += np.sum(topn_predictions)
                all_count += top1_predictions.shape[0]
                total_loss += val_loss * val_lbl.shape[0]

                print('Batch Number: %d, Top-1 Hit: %d, Top-%d Hit: %d, Loss %.2f, '
                      'Top-1 Accuracy: %.3f, Top-%d Accuracy: %.3f' %
                      (step, true_predictions_count, args.top_n,
                       true_topn_predictions_count, total_loss / all_count,
                       true_predictions_count / all_count, args.top_n,
                       true_topn_predictions_count / all_count))

                # Log results into the output file
                for i in range(0, val_inf.shape[0]):
                    out_file.write(predictions_format_str %
                                   (step * args.batch_size + i + 1,
                                    str(val_inf[i]).encode('utf-8'),
                                    val_lbl[i],
                                    ', '.join('%d' % item for item in topnguesses[i]),
                                    ', '.join('%.4f' % item for item in topnconf[i])))
                out_file.flush()

            out_file.close()

        # Inference only
        else:
            # Open an output file to write predictions
            out_file = open(args.save_predictions, 'w')
            predictions_format_str = ('%d, %s, %s, %s\n')

            for step in range(args.num_val_batches):
                # Load a batch of data
                val_img, val_inf = sess.run(
                    [val_images, val_info],
                    feed_dict={batch_size_tf: args.num_val_samples % args.batch_size}
                    if step == args.num_val_batches - 1 else None)

                # Run the network on the loaded batch
                topnguesses, topnconf = sess.run(
                    eval_ops,
                    feed_dict={images_ph: val_img, is_training_ph: False},
                    options=args.run_options,
                    run_metadata=args.run_metadata)

                print('Batch Number: %d of %d is done' % (step, args.num_val_batches))

                # Log results into the output file
                for i in range(0, val_inf.shape[0]):
                    out_file.write(predictions_format_str %
                                   (step * args.batch_size + i + 1,
                                    str(val_inf[i]).encode('utf-8'),
                                    ', '.join('%d' % item for item in topnguesses[i]),
                                    ', '.join('%.4f' % item for item in topnconf[i])))
                out_file.flush()

            out_file.close()

        coord.request_stop()
        coord.join(threads)
        sess.close()
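

# ---------------------------------------------------------------------------
# Usage sketch (assumption): this module is presumably driven by a launcher
# that builds an `args` namespace carrying the fields referenced above
# (architecture, batch_size, num_epochs, log_dir, run_options, ...) and then
# hands a session to do_train or do_evaluate. The wiring below is only an
# illustration; the `command` flag and the session configuration are
# assumptions, not this project's actual CLI.
# ---------------------------------------------------------------------------
#     config = tf.ConfigProto(allow_soft_placement=True)   # fall back to CPU for unplaced ops
#     config.gpu_options.allow_growth = True                # claim GPU memory on demand
#     sess = tf.Session(config=config)
#     if args.command == 'train':
#         do_train(sess, args)       # trains and checkpoints into args.log_dir
#     else:
#         do_evaluate(sess, args)    # restores from args.log_dir and writes predictions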