def main(_):
    """Train a VggNet classifier and log summaries/checkpoints.

    Creates a timestamped run directory under FLAGS.tensorboard_root_dir,
    records the active flags to flags.txt, then trains for FLAGS.num_epochs,
    validating and checkpointing after each epoch.

    Relies on module-level FLAGS, VggNetModel and BatchPreprocessor.
    """
    # Create training directories
    now = datetime.datetime.now()
    train_dir_name = now.strftime('vggnet_%Y%m%d_%H%M%S')
    train_dir = os.path.join(FLAGS.tensorboard_root_dir, train_dir_name)
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')

    # Create each directory level in order (parents before children).
    for directory in (FLAGS.tensorboard_root_dir, train_dir, checkpoint_dir,
                      tensorboard_dir, tensorboard_train_dir,
                      tensorboard_val_dir):
        if not os.path.isdir(directory):
            os.mkdir(directory)

    # Write flags to txt. The context manager guarantees the handle is
    # closed even if a write fails (the original leaked it on error).
    flags_file_path = os.path.join(train_dir, 'flags.txt')
    with open(flags_file_path, 'w') as flags_file:
        flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
        flags_file.write('dropout_keep_prob={}\n'.format(
            FLAGS.dropout_keep_prob))
        flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
        flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
        flags_file.write('tensorboard_root_dir={}\n'.format(
            FLAGS.tensorboard_root_dir))
        flags_file.write('log_step={}\n'.format(FLAGS.log_step))

    # Placeholders
    img_size = 256
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, img_size, img_size, 3])
    y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model
    model = VggNetModel(num_classes=FLAGS.num_classes,
                        dropout_keep_prob=dropout_keep_prob)
    loss = model.loss(x, y)
    train_op = model.optimize(FLAGS.learning_rate)

    # Training accuracy of the model
    correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Summaries
    tf.summary.scalar('train_loss', loss)
    tf.summary.scalar('train_accuracy', accuracy)
    merged_summary = tf.summary.merge_all()

    train_writer = tf.summary.FileWriter(tensorboard_train_dir)
    val_writer = tf.summary.FileWriter(tensorboard_val_dir)
    saver = tf.train.Saver()

    # Batch preprocessors
    train_preprocessor = BatchPreprocessor(
        dataset_file_path=FLAGS.training_file,
        num_classes=FLAGS.num_classes,
        output_size=[img_size, img_size],
        horizontal_flip=True,
        shuffle=True)
    val_preprocessor = BatchPreprocessor(
        dataset_file_path=FLAGS.val_file,
        num_classes=FLAGS.num_classes,
        output_size=[img_size, img_size])

    # Number of full batches per epoch. Plain integer floor division avoids
    # the np.int16 overflow the original risked on datasets with more than
    # 32767 batches.
    train_batches_per_epoch = len(train_preprocessor.labels) // FLAGS.batch_size
    val_batches_per_epoch = len(val_preprocessor.labels) // FLAGS.batch_size

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_writer.add_graph(sess.graph)

        print("{} Start training...".format(datetime.datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(
            datetime.datetime.now(), tensorboard_dir))

        for epoch in range(FLAGS.num_epochs):
            print("{} Epoch number: {}".format(datetime.datetime.now(),
                                               epoch + 1))
            step = 1

            # Start training. NOTE: '<=' fixes the original off-by-one
            # ('< train_batches_per_epoch' with step starting at 1 skipped
            # the last batch of every epoch).
            while step <= train_batches_per_epoch:
                batch_xs, batch_ys = train_preprocessor.next_batch(
                    FLAGS.batch_size)
                sess.run(train_op, feed_dict={
                    x: batch_xs,
                    y: batch_ys,
                    dropout_keep_prob: FLAGS.dropout_keep_prob
                })

                # Logging: summaries are computed with dropout disabled.
                if step % FLAGS.log_step == 0:
                    s = sess.run(merged_summary, feed_dict={
                        x: batch_xs,
                        y: batch_ys,
                        dropout_keep_prob: 1.
                    })
                    train_writer.add_summary(
                        s, epoch * train_batches_per_epoch + step)

                step += 1

            # Epoch completed, start validation
            print("{} Start validation".format(datetime.datetime.now()))
            test_acc = 0.
            test_count = 0
            for _ in range(val_batches_per_epoch):
                batch_tx, batch_ty = val_preprocessor.next_batch(
                    FLAGS.batch_size, 1)
                acc = sess.run(accuracy, feed_dict={
                    x: batch_tx,
                    y: batch_ty,
                    dropout_keep_prob: 1.
                })
                test_acc += acc
                test_count += 1
            # Guard against ZeroDivisionError when the validation set holds
            # fewer examples than one batch (val_batches_per_epoch == 0).
            if test_count > 0:
                test_acc /= test_count

            # Hand-built summary so validation accuracy appears in
            # TensorBoard without a placeholder round-trip.
            s = tf.Summary(value=[
                tf.Summary.Value(tag="validation_accuracy",
                                 simple_value=test_acc)
            ])
            val_writer.add_summary(s, epoch + 1)
            print("{} Validation Accuracy = {:.4f}".format(
                datetime.datetime.now(), test_acc))

            # Reset the dataset pointers
            val_preprocessor.reset_pointer()
            train_preprocessor.reset_pointer()

            print("{} Saving checkpoint of model...".format(
                datetime.datetime.now()))

            # Save checkpoint of the model (one file per epoch).
            checkpoint_path = os.path.join(
                checkpoint_dir, 'model_epoch' + str(epoch + 1) + '.ckpt')
            saver.save(sess, checkpoint_path)
            print("{} Model checkpoint saved at {}".format(
                datetime.datetime.now(), checkpoint_path))
def main(_):
    """Evaluate a trained VggNet checkpoint on the validation set.

    Restores FLAGS.ckpt_path, runs accuracy over the validation data for
    FLAGS.num_epochs passes, and records the flags plus the final accuracy
    to flags.txt in a fresh timestamped directory.

    NOTE(review): this redefines main() from the training entry point
    earlier in the file; presumably the two definitions came from separate
    scripts -- confirm before merging.
    """
    # Create training directories
    now = datetime.datetime.now()
    train_dir_name = now.strftime('vggnet_%Y%m%d_%H%M%S')
    train_dir = os.path.join(FLAGS.tensorboard_root_dir, train_dir_name)
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')

    # Create each directory level in order (parents before children).
    for directory in (FLAGS.tensorboard_root_dir, train_dir, checkpoint_dir,
                      tensorboard_dir, tensorboard_train_dir,
                      tensorboard_val_dir):
        if not os.path.isdir(directory):
            os.mkdir(directory)

    # Placeholders
    img_size = 256
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, img_size, img_size, 3])
    y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model. loss/train_op are never run here, but building them creates the
    # optimizer variables the Saver expects when restoring a training
    # checkpoint, so they must stay.
    model = VggNetModel(num_classes=FLAGS.num_classes,
                        dropout_keep_prob=dropout_keep_prob)
    loss = model.loss(x, y)
    train_op = model.optimize(FLAGS.learning_rate)

    # Training accuracy of the model
    correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Summaries (scalars are defined but never evaluated in this script;
    # kept for graph parity with the training entry point).
    tf.summary.scalar('train_loss', loss)
    tf.summary.scalar('train_accuracy', accuracy)
    merged_summary = tf.summary.merge_all()

    train_writer = tf.summary.FileWriter(tensorboard_train_dir)
    val_writer = tf.summary.FileWriter(tensorboard_val_dir)
    saver = tf.train.Saver()

    # Batch preprocessor for the validation data only.
    val_preprocessor = BatchPreprocessor(
        dataset_file_path=FLAGS.val_file,
        num_classes=FLAGS.num_classes,
        output_size=[img_size, img_size])

    # Full batches in the validation set. Plain integer floor division
    # avoids the np.int16 overflow the original risked past 32767 batches.
    val_batches_per_epoch = len(val_preprocessor.labels) // FLAGS.batch_size

    test_accuracy = 0.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_writer.add_graph(sess.graph)

        # Directly restore (your model should be exactly the same with checkpoint)
        saver.restore(sess, FLAGS.ckpt_path)

        print("{} Start training...".format(datetime.datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(
            datetime.datetime.now(), tensorboard_dir))

        for epoch in range(FLAGS.num_epochs):
            print("{} Epoch number: {}".format(datetime.datetime.now(),
                                               epoch + 1))

            # Evaluate one full pass over the validation data. (The unused
            # 'step = 1' local from the original was dropped.)
            print("{} Start Test".format(datetime.datetime.now()))
            test_acc = 0.
            test_count = 0
            for _ in range(val_batches_per_epoch):
                batch_tx, batch_ty = val_preprocessor.next_batch(
                    FLAGS.batch_size, 1)
                acc = sess.run(accuracy, feed_dict={
                    x: batch_tx,
                    y: batch_ty,
                    dropout_keep_prob: 1.
                })
                test_acc += acc
                test_count += 1
            # Guard against ZeroDivisionError when the validation set holds
            # fewer examples than one batch (val_batches_per_epoch == 0).
            if test_count > 0:
                test_acc /= test_count
            print("{} Test Accuracy = {:.4f}".format(datetime.datetime.now(),
                                                     test_acc))
            test_accuracy = test_acc

            # Reset the dataset pointers
            val_preprocessor.reset_pointer()

    # Write flags to txt. 'w' mode overwrites, so only the final accuracy is
    # kept -- same end state as the original. The context manager closes the
    # handle even on error (the original leaked it).
    flags_file_path = os.path.join(train_dir, 'flags.txt')
    with open(flags_file_path, 'w') as flags_file:
        flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
        flags_file.write('log_step={}\n'.format(FLAGS.log_step))
        flags_file.write('checkpoint_path={}\n'.format(FLAGS.ckpt_path))
        flags_file.write('test_accuracy={}'.format(test_accuracy))