Ejemplo n.º 1
0
def main(_):
    """Fine-tune a VggNetModel, validating and checkpointing after each epoch.

    A timestamped run directory ``vggnet_YYYYmmdd_HHMMSS`` is created under
    ``FLAGS.tensorboard_root_dir``, with ``checkpoint/`` and
    ``tensorboard/{train,val}/`` subdirectories, and the run's
    hyper-parameters are written to ``flags.txt`` inside it.  Then, for each
    of ``FLAGS.num_epochs`` epochs:

    * every full batch of ``FLAGS.training_file`` is trained on once, and
      train loss/accuracy summaries are written every ``FLAGS.log_step``
      steps;
    * the mean accuracy over all full batches of ``FLAGS.val_file`` is
      printed and written as a ``validation_accuracy`` summary;
    * a checkpoint ``model_epoch<N>.ckpt`` is saved.

    Args:
        _: Unused positional argument (presumably argv from the app runner).

    Raises:
        ValueError: If the training or validation set holds fewer than
            ``FLAGS.batch_size`` examples, i.e. not even one full batch.
    """
    # Create training directories (listed parents-before-children so that
    # os.mkdir never needs a missing parent).
    now = datetime.datetime.now()
    train_dir_name = now.strftime('vggnet_%Y%m%d_%H%M%S')
    train_dir = os.path.join(FLAGS.tensorboard_root_dir, train_dir_name)
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')

    for directory in (FLAGS.tensorboard_root_dir, train_dir, checkpoint_dir,
                      tensorboard_dir, tensorboard_train_dir,
                      tensorboard_val_dir):
        if not os.path.isdir(directory):
            os.mkdir(directory)

    # Record the hyper-parameters of this run next to its outputs.
    flags_file_path = os.path.join(train_dir, 'flags.txt')
    with open(flags_file_path, 'w') as flags_file:
        flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
        flags_file.write(
            'dropout_keep_prob={}\n'.format(FLAGS.dropout_keep_prob))
        flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
        flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
        flags_file.write('tensorboard_root_dir={}\n'.format(
            FLAGS.tensorboard_root_dir))
        flags_file.write('log_step={}\n'.format(FLAGS.log_step))

    # Placeholders. The image batch dimension is fixed to FLAGS.batch_size,
    # so only full batches can be fed to the network.
    img_size = 256
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, img_size, img_size, 3])
    y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model
    model = VggNetModel(num_classes=FLAGS.num_classes,
                        dropout_keep_prob=dropout_keep_prob)
    loss = model.loss(x, y)
    train_op = model.optimize(FLAGS.learning_rate)

    # Accuracy: fraction of rows whose highest score matches the arg-max of y
    # (y presumably one-hot labels -- confirm against BatchPreprocessor).
    correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Summaries
    tf.summary.scalar('train_loss', loss)
    tf.summary.scalar('train_accuracy', accuracy)
    merged_summary = tf.summary.merge_all()

    train_writer = tf.summary.FileWriter(tensorboard_train_dir)
    val_writer = tf.summary.FileWriter(tensorboard_val_dir)
    saver = tf.train.Saver()

    # Batch preprocessors
    train_preprocessor = BatchPreprocessor(
        dataset_file_path=FLAGS.training_file,
        num_classes=FLAGS.num_classes,
        output_size=[img_size, img_size],
        horizontal_flip=True,
        shuffle=True)
    val_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.val_file,
                                         num_classes=FLAGS.num_classes,
                                         output_size=[img_size, img_size])

    # Number of full batches per epoch; a trailing partial batch is dropped
    # because the x placeholder has a fixed batch size. Plain Python ints are
    # used so large datasets cannot overflow a fixed-width integer type.
    num_train = len(train_preprocessor.labels)
    num_val = len(val_preprocessor.labels)
    train_batches_per_epoch = int(num_train // FLAGS.batch_size)
    val_batches_per_epoch = int(num_val // FLAGS.batch_size)
    # Fail fast rather than training for an epoch and then dividing by zero
    # when averaging the validation accuracy.
    if train_batches_per_epoch == 0 or val_batches_per_epoch == 0:
        raise ValueError(
            'batch_size={} needs at least one full batch, but the training '
            'set has {} and the validation set has {} examples.'.format(
                FLAGS.batch_size, num_train, num_val))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_writer.add_graph(sess.graph)

        # To resume from an existing checkpoint, restore it here, e.g.
        # saver.restore(sess, '/path/to/model_epochN.ckpt'); the graph built
        # above must match the checkpoint exactly.

        print("{} Start training...".format(datetime.datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(
            datetime.datetime.now(), tensorboard_dir))

        for epoch in range(FLAGS.num_epochs):
            print("{} Epoch number: {}".format(datetime.datetime.now(),
                                               epoch + 1))

            # Training: one pass over every full batch. Steps are 1-based, so
            # epoch * train_batches_per_epoch + step is a unique global step.
            for step in range(1, train_batches_per_epoch + 1):
                batch_xs, batch_ys = train_preprocessor.next_batch(
                    FLAGS.batch_size)
                sess.run(train_op,
                         feed_dict={
                             x: batch_xs,
                             y: batch_ys,
                             dropout_keep_prob: FLAGS.dropout_keep_prob
                         })

                # Logging: re-evaluate the same batch with dropout disabled.
                if step % FLAGS.log_step == 0:
                    s = sess.run(merged_summary,
                                 feed_dict={
                                     x: batch_xs,
                                     y: batch_ys,
                                     dropout_keep_prob: 1.
                                 })
                    train_writer.add_summary(
                        s, epoch * train_batches_per_epoch + step)

            # Epoch completed, start validation
            print("{} Start validation".format(datetime.datetime.now()))
            test_acc = 0.
            for _ in range(val_batches_per_epoch):
                batch_tx, batch_ty = val_preprocessor.next_batch(
                    FLAGS.batch_size, 1)
                test_acc += sess.run(accuracy,
                                     feed_dict={
                                         x: batch_tx,
                                         y: batch_ty,
                                         dropout_keep_prob: 1.
                                     })
            # All batches are the same size, so the mean of per-batch
            # accuracies equals the accuracy over the evaluated examples.
            test_acc /= val_batches_per_epoch

            s = tf.Summary(value=[
                tf.Summary.Value(tag="validation_accuracy",
                                 simple_value=test_acc)
            ])
            val_writer.add_summary(s, epoch + 1)
            print("{} Validation Accuracy = {:.4f}".format(
                datetime.datetime.now(), test_acc))

            # Reset the dataset pointers so the next epoch starts over.
            val_preprocessor.reset_pointer()
            train_preprocessor.reset_pointer()

            print("{} Saving checkpoint of model...".format(
                datetime.datetime.now()))

            # Save a checkpoint of the model after every epoch.
            checkpoint_path = os.path.join(
                checkpoint_dir, 'model_epoch' + str(epoch + 1) + '.ckpt')
            saver.save(sess, checkpoint_path)

            print("{} Model checkpoint saved at {}".format(
                datetime.datetime.now(), checkpoint_path))
Ejemplo n.º 2
0
def main(_):
    """Evaluate a saved VggNetModel checkpoint on ``FLAGS.val_file``.

    A timestamped run directory ``vggnet_YYYYmmdd_HHMMSS`` is created under
    ``FLAGS.tensorboard_root_dir``, with ``checkpoint/`` and
    ``tensorboard/{train,val}/`` subdirectories. The model weights are then
    restored from ``FLAGS.ckpt_path``, and ``FLAGS.num_epochs`` accuracy
    passes are run over all full batches of the evaluation set. Finally,
    ``batch_size``, ``log_step``, ``checkpoint_path`` and the accuracy of the
    last pass are written to ``flags.txt`` in the run directory.

    Args:
        _: Unused positional argument (presumably argv from the app runner).

    Raises:
        ValueError: If the evaluation set holds fewer than
            ``FLAGS.batch_size`` examples, i.e. not even one full batch.
    """
    # Create run directories (listed parents-before-children so that
    # os.mkdir never needs a missing parent).
    now = datetime.datetime.now()
    train_dir_name = now.strftime('vggnet_%Y%m%d_%H%M%S')
    train_dir = os.path.join(FLAGS.tensorboard_root_dir, train_dir_name)
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')

    for directory in (FLAGS.tensorboard_root_dir, train_dir, checkpoint_dir,
                      tensorboard_dir, tensorboard_train_dir,
                      tensorboard_val_dir):
        if not os.path.isdir(directory):
            os.mkdir(directory)

    # Placeholders. The image batch dimension is fixed to FLAGS.batch_size,
    # so only full batches can be fed to the network.
    img_size = 256
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, img_size, img_size, 3])
    y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
    dropout_keep_prob = tf.placeholder(tf.float32)

    # Model. The loss and optimizer ops are not run here. They are still built
    # so the graph's variables match those saved in the checkpoint (the
    # optimizer presumably adds its own variables -- confirm), because the
    # Saver below covers all variables and restore() needs them present.
    model = VggNetModel(num_classes=FLAGS.num_classes,
                        dropout_keep_prob=dropout_keep_prob)
    model.loss(x, y)
    model.optimize(FLAGS.learning_rate)

    # Accuracy: fraction of rows whose highest score matches the arg-max of y
    # (y presumably one-hot labels -- confirm against BatchPreprocessor).
    correct_pred = tf.equal(tf.argmax(model.score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    train_writer = tf.summary.FileWriter(tensorboard_train_dir)
    saver = tf.train.Saver()

    # Batch preprocessor for the evaluation set
    val_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.val_file,
                                         num_classes=FLAGS.num_classes,
                                         output_size=[img_size, img_size])

    # Number of full batches per pass; a trailing partial batch is dropped
    # because the x placeholder has a fixed batch size. A plain Python int is
    # used so large datasets cannot overflow a fixed-width integer type.
    num_val = len(val_preprocessor.labels)
    val_batches_per_epoch = int(num_val // FLAGS.batch_size)
    # Fail fast rather than dividing by zero when averaging the accuracy.
    if val_batches_per_epoch == 0:
        raise ValueError(
            'batch_size={} needs at least one full batch, but the evaluation '
            'set has only {} examples.'.format(FLAGS.batch_size, num_val))

    # Accuracy of the most recent pass; stays 0 if FLAGS.num_epochs == 0.
    test_accuracy = 0

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_writer.add_graph(sess.graph)

        # Overwrite the freshly initialised variables with the trained
        # weights; the graph built above must match the checkpoint exactly.
        saver.restore(sess, FLAGS.ckpt_path)

        print("{} Start testing...".format(datetime.datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(
            datetime.datetime.now(), tensorboard_dir))

        for epoch in range(FLAGS.num_epochs):
            print("{} Epoch number: {}".format(datetime.datetime.now(),
                                               epoch + 1))

            print("{} Start Test".format(datetime.datetime.now()))
            test_acc = 0.
            for _ in range(val_batches_per_epoch):
                batch_tx, batch_ty = val_preprocessor.next_batch(
                    FLAGS.batch_size, 1)
                test_acc += sess.run(accuracy,
                                     feed_dict={
                                         x: batch_tx,
                                         y: batch_ty,
                                         dropout_keep_prob: 1.
                                     })
            # All batches are the same size, so the mean of per-batch
            # accuracies equals the accuracy over the evaluated examples.
            test_acc /= val_batches_per_epoch
            print("{} Test Accuracy = {:.4f}".format(datetime.datetime.now(),
                                                     test_acc))
            test_accuracy = test_acc

            # Reset the dataset pointer so the next pass starts over.
            val_preprocessor.reset_pointer()

    # Write the settings and the final accuracy to txt.
    flags_file_path = os.path.join(train_dir, 'flags.txt')
    with open(flags_file_path, 'w') as flags_file:
        flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
        flags_file.write('log_step={}\n'.format(FLAGS.log_step))
        flags_file.write('checkpoint_path={}\n'.format(FLAGS.ckpt_path))
        flags_file.write('test_accuracy={}'.format(test_accuracy))