def run_training(data):
    """Build the model graph and run the training loop.

    Creates placeholders for batches of ``BATCH_SIZE`` float32 images of
    shape 32x32x3 and their int32 labels, builds the inference, loss, train
    and evaluate ops via ``graph``, and then runs
    ``N_EPOCH * (DS_SIZE // BATCH_SIZE)`` training steps. Along the way it
    prints the loss, writes summaries to ``SUMMARY_DIR``, calls ``do_eval``
    on the train/validation/test splits and saves checkpoints under
    ``CHECKPOINT_DIR``.

    Args:
        data: Object exposing ``train``, ``validation`` and ``test``
            attributes; these are passed to ``fill_feed_dict`` and
            ``do_eval``.

    Raises:
        AssertionError: If the loss value of a training step is NaN.
    """
    # Loop-invariant: computed once instead of at every comparison.
    total_steps = N_EPOCH * (DS_SIZE // BATCH_SIZE)
    last_step = total_steps - 1

    with tf.Graph().as_default():
        global_step = tf.Variable(0, name='global_step', trainable=False)
        images_pl = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 32, 32, 3])
        labels_pl = tf.placeholder(tf.int32, shape=[BATCH_SIZE])

        logits = graph.inference(images_pl)
        loss = graph.loss(logits, labels_pl)
        train_op = graph.train(loss, global_step)
        eval_correct = graph.evaluate(logits, labels_pl)

        summary_op = tf.merge_all_summaries()
        saver = tf.train.Saver(tf.all_variables())
        init = tf.initialize_all_variables()

        # The context manager closes the session even if training raises.
        with tf.Session() as sess:
            sess.run(init)
            summary_writer = tf.train.SummaryWriter(SUMMARY_DIR, sess.graph)
            try:
                for step in range(total_steps):
                    is_last_step = step == last_step

                    start_time = time.time()
                    feed_dict = fill_feed_dict(data.train, images_pl, labels_pl)
                    _, loss_val = sess.run([train_op, loss],
                                           feed_dict=feed_dict)
                    duration = time.time() - start_time

                    # Explicit raise (not ``assert``) so the divergence check
                    # is still enforced when Python runs with -O.
                    if np.isnan(loss_val):
                        raise AssertionError('Model diverged with loss = NaN')

                    if step % 10 == 0 or is_last_step:
                        print('Step %d: loss = %.2f (%.3f sec)'
                              % (step, loss_val, duration))
                        if step > 0:
                            summary_str = sess.run(summary_op,
                                                   feed_dict=feed_dict)
                            summary_writer.add_summary(summary_str, step)
                            summary_writer.flush()

                    # NOTE(review): the nesting below was reconstructed from a
                    # whitespace-collapsed source; confirm that all
                    # evaluations are meant to be skipped at step 0.
                    if step > 0:
                        # Every 200 steps during the first 1000, then every
                        # 1000 steps, and always after the final step.
                        early_eval = step < 1000 and step % 200 == 0
                        periodic_eval = step % 1000 == 0 or is_last_step
                        if early_eval or periodic_eval:
                            print('Training Data Eval:')
                            do_eval(sess, eval_correct, images_pl, labels_pl,
                                    data.train)
                            print('Validation Data Eval:')
                            do_eval(sess, eval_correct, images_pl, labels_pl,
                                    data.validation)
                        if is_last_step:
                            print('Test Data Eval:')
                            do_eval(sess, eval_correct, images_pl, labels_pl,
                                    data.test)

                    # Save the model checkpoint periodically.
                    # NOTE(review): assumed to sit at loop level (so a
                    # checkpoint is also written at step 0) — confirm.
                    if step % 1000 == 0 or is_last_step:
                        saver.save(sess, CHECKPOINT_DIR, global_step=step)
            finally:
                # Flush pending events and release the event file handle.
                summary_writer.close()
def run_training(data):
    """Build the model graph and run the training loop.

    Creates placeholders for batches of ``BATCH_SIZE`` float32 images of
    shape 512x512x3 and their int32 labels, builds the inference, loss,
    train and evaluate ops via ``graph`` (``graph.train`` is called with
    ``0.0001`` as its second argument), and then runs
    ``N_EPOCH * (DS_SIZE // BATCH_SIZE)`` training steps. Summaries are
    written to the ``"summary"`` directory; every 100 steps and after the
    final step the trainable variables are saved to ``"model.ckpt"`` and
    ``do_eval`` is called on the train and validation splits.

    Args:
        data: Object exposing ``train`` and ``validation`` attributes;
            these are passed to ``fill_feed_dict`` and ``do_eval``.

    Raises:
        AssertionError: If the loss value of a training step is NaN.
    """
    # Loop-invariant: computed once instead of at every comparison.
    total_steps = N_EPOCH * (DS_SIZE // BATCH_SIZE)
    last_step = total_steps - 1

    with tf.Graph().as_default():
        images_pl = tf.placeholder(tf.float32,
                                   shape=[BATCH_SIZE, 512, 512, 3])
        labels_pl = tf.placeholder(tf.int32, shape=[BATCH_SIZE])

        logits = graph.inference(images_pl)
        loss = graph.loss(logits, labels_pl)
        train_op = graph.train(loss, 0.0001)
        eval_correct = graph.evaluate(logits, labels_pl)

        # Only trainable variables are checkpointed here.
        saver = tf.train.Saver(tf.trainable_variables())
        summary_op = tf.merge_all_summaries()
        init = tf.initialize_all_variables()

        # The context manager closes the session even if training raises.
        with tf.Session() as sess:
            sess.run(init)
            summary_writer = tf.train.SummaryWriter("summary", sess.graph)
            try:
                for step in range(total_steps):
                    is_last_step = step == last_step

                    start_time = time.time()
                    feed_dict = fill_feed_dict(data.train, images_pl, labels_pl)
                    _, loss_val = sess.run([train_op, loss],
                                           feed_dict=feed_dict)
                    duration = time.time() - start_time

                    # Explicit raise (not ``assert``) so the divergence check
                    # is still enforced when Python runs with -O.
                    if np.isnan(loss_val):
                        raise AssertionError('Model diverged with loss = NaN')

                    if step % 10 == 0 or is_last_step:
                        print('Step %d: loss = %.2f (%.3f sec)'
                              % (step, loss_val, duration))
                        if step > 0:
                            summary_str = sess.run(summary_op,
                                                   feed_dict=feed_dict)
                            summary_writer.add_summary(summary_str, step)
                            summary_writer.flush()

                    # NOTE(review): nesting reconstructed from a
                    # whitespace-collapsed source; the evaluations are assumed
                    # to run together with the checkpoint save — confirm.
                    if step % 100 == 0 or is_last_step:
                        save_path = saver.save(sess, "model.ckpt")
                        print("Model saved in file: %s" % save_path)
                        print('Training Data Eval:')
                        do_eval(sess, eval_correct, images_pl, labels_pl,
                                data.train)
                        print('Validation Data Eval:')
                        do_eval(sess, eval_correct, images_pl, labels_pl,
                                data.validation)
            finally:
                # Flush pending events and release the event file handle.
                summary_writer.close()
def run_evaluation(data, checkpoint_path="checkpoints/-4500"):
    """Restore a saved model and evaluate it on ``data.test``.

    Rebuilds the inference and evaluate ops for batches of ``BATCH_SIZE``
    float32 images of shape 32x32x3 with int32 labels, restores all
    variables from ``checkpoint_path`` and calls ``do_eval`` on the test
    split.

    Args:
        data: Object exposing a ``test`` attribute, which is passed to
            ``do_eval``.
        checkpoint_path: Checkpoint to restore. Defaults to the path that
            was previously hard-coded, so existing callers are unaffected.
    """
    with tf.Graph().as_default():
        images_pl = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 32, 32, 3])
        labels_pl = tf.placeholder(tf.int32, shape=[BATCH_SIZE])

        logits = graph.inference(images_pl)
        eval_correct = graph.evaluate(logits, labels_pl)
        saver = tf.train.Saver(tf.all_variables())

        # The context manager closes the session even if evaluation raises.
        with tf.Session() as sess:
            # No initializer op is run: variable values come from the
            # checkpoint restore below.
            saver.restore(sess, checkpoint_path)
            print("Model restored.")
            do_eval(sess, eval_correct, images_pl, labels_pl, data.test)