def run_test(model_arch_module, dataset_dir, model_path, result_dir, top_k=1):
    """Evaluate a saved checkpoint on the test split and dump raw predictions.

    Builds the graph with ``model_arch_module.build_model_arch()``, restores
    the weights from ``model_path``, and runs ``model.predict`` over every
    batch of the test split returned by ``data_reader.read_data_set_dir``
    (batch size 64). Per-step and overall top-``top_k`` accuracy are printed.

    Two pickle files are written to ``result_dir`` (created if missing):
    ``truth_indices.pkl`` holds the list of ground-truth class indices, and
    ``pred_probs.pkl`` holds the list of per-sample prediction rows, as
    returned by ``model.predict``.

    Args:
        model_arch_module: Module exposing ``build_model_arch()``.
        dataset_dir: Directory passed to ``data_reader.read_data_set_dir``.
        model_path: Checkpoint path passed to ``tf.train.Saver.restore``.
        result_dir: Output directory for the pickle files.
        top_k: A sample counts as correct when its true class is among the
            ``top_k`` highest predictions. For ``top_k == 1`` a plain argmax
            comparison is used; otherwise ``tf.nn.in_top_k`` decides, ties
            included.
    """
    import data_config as cfg

    os.makedirs(result_dir, exist_ok=True)
    model = model_arch_module.build_model_arch()
    _, _, test_data = data_reader.read_data_set_dir(dataset_dir, cfg.one_hot, 64)
    saver = tf.train.Saver()

    # Build the top-k op ONCE, fed through placeholders. Calling
    # tf.nn.in_top_k inside the batch loop adds fresh nodes (including
    # constants holding the whole batch) to the default graph on every step.
    in_top_k_op = None
    if top_k != 1:
        pred_input = tf.placeholder(tf.float32, shape=[None, None])
        truth_input = tf.placeholder(tf.int64, shape=[None])
        in_top_k_op = tf.nn.in_top_k(pred_input, truth_input, k=top_k)

    with tf.Session() as sess:
        saver.restore(sess, model_path)
        overall_truth_indices = []
        overall_pred_probs = []
        overall_correct_classification = 0.
        for step in range(test_data.batch_count):
            images, one_hot = test_data.next_batch()
            truth_indexes = np.argmax(one_hot, 1)
            pred = model.predict(sess, images)[0]
            if in_top_k_op is None:
                pred_indexes = np.argmax(pred, 1)
                correct_classification = (
                    truth_indexes == pred_indexes).astype(np.float32)
            else:
                correct_classification = sess.run(
                    in_top_k_op,
                    feed_dict={pred_input: pred, truth_input: truth_indexes},
                ).astype(np.float32)
            overall_truth_indices.extend(truth_indexes)
            overall_pred_probs.extend(pred)
            overall_correct_classification += correct_classification.sum()
            print("Step %d, accuracy %f" % (step, correct_classification.mean()))

        # Divide by the number of samples actually evaluated, so numerator
        # and denominator agree even if the batches do not cover exactly
        # ``test_data.data_set_size`` samples. An empty test split gives NaN
        # (accuracy undefined) instead of a ZeroDivisionError.
        evaluated = len(overall_truth_indices)
        if evaluated:
            overall_acc = overall_correct_classification / float(evaluated)
        else:
            overall_acc = float('nan')
        print("Overall accuracy: %f" % overall_acc)

    with open(os.path.join(result_dir, 'truth_indices.pkl'), 'wb') as f:
        pickle.dump(overall_truth_indices, f)
    with open(os.path.join(result_dir, 'pred_probs.pkl'), 'wb') as f:
        pickle.dump(overall_pred_probs, f)
def run_test_visual(model_arch_module, dataset_dir, model_path):
    """Restore a checkpoint and visually inspect its predictions on the test split.

    The test split is read via ``data_reader.read_data_set_dir`` with batch
    size 24. For every batch, the images, the predicted label names (looked
    up in ``data_config.one_hot_labels``), and the ground-truth and predicted
    class indices are handed to ``data_visualizer.show_images_with_truth``.

    Args:
        model_arch_module: Module exposing ``build_model_arch()``.
        dataset_dir: Directory passed to ``data_reader.read_data_set_dir``.
        model_path: Checkpoint path passed to ``tf.train.Saver.restore``.
    """
    import data_visualizer as dv
    import data_config as cfg

    network = model_arch_module.build_model_arch()
    test_split = data_reader.read_data_set_dir(dataset_dir, cfg.one_hot, 24)[2]
    checkpoint_loader = tf.train.Saver()

    with tf.Session() as session:
        checkpoint_loader.restore(session, model_path)
        class_names = cfg.one_hot_labels
        for _ in range(test_split.batch_count):
            batch_images, batch_one_hot = test_split.next_batch()
            true_ids = np.argmax(batch_one_hot, axis=1)
            scores = network.predict(session, batch_images)[0]
            predicted_ids = np.argmax(scores, axis=1)
            predicted_names = [class_names[idx] for idx in predicted_ids]
            dv.show_images_with_truth(
                batch_images, predicted_names, true_ids, predicted_ids)
def run_trainer(model_arch_module, num_epochs, batch_size, dataset_path, model_name, run_name):
    """Train a model, log to TensorBoard, and checkpoint after every epoch.

    Each epoch does the following:
      * Iterates over all training batches, calling ``model.train_step``.
        Every 10th step runs with ``run_summary=True``, and that summary is
        written at the running global step.
      * Averages the per-batch training accuracy/loss with the ``train_mean``
        op, prints the averages, and writes the summary at index ``epoch``.
      * Runs ``model.evaluate`` over all validation batches and averages them
        the same way with ``val_mean``.
      * Saves a checkpoint under ``./model/<model_name>`` with
        ``global_step=epoch``.

    Ctrl-C (``KeyboardInterrupt``) stops training early without raising.
    The summary writer (``logs/<run_name>/``) is always closed, even when
    another exception escapes.

    Args:
        model_arch_module: Module exposing ``build_model_arch()``.
        num_epochs: Number of passes over the training split.
        batch_size: Batch size passed to ``data_reader.read_data_set_dir``.
        dataset_path: Directory passed to ``data_reader.read_data_set_dir``.
        model_name: Checkpoint file name under ``./model/``.
        run_name: TensorBoard run directory name under ``logs/``.
    """
    import os
    import tensorflow as tf
    import data_reader
    import data_config as cfg

    model = model_arch_module.build_model_arch()
    train_data, val_data, _ = data_reader.read_data_set_dir(
        dataset_path, cfg.one_hot, batch_size)
    # get_mean_op is defined elsewhere in this file. Presumably it returns
    # (accuracies placeholder, losses placeholder, train mean op, val mean op);
    # TODO confirm.
    accuracies_input, losses_input, train_mean, val_mean = get_mean_op()
    global_step = 0
    save_path = './model/%s' % model_name
    # TF1's Saver.save raises ValueError if the parent directory is missing.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    file_writer = tf.summary.FileWriter('logs/%s/' % run_name)
    saver = tf.train.Saver()

    def _log_epoch_mean(sess, mean_op, phase, epoch, accuracies, losses):
        """Average per-batch metrics with mean_op, print them, and write the summary."""
        mean_acc, mean_loss, summ = sess.run(mean_op, feed_dict={
            accuracies_input: accuracies,
            losses_input: losses,
        })
        print('Epoch %d: %s accuracy: %f, loss %f' % (epoch, phase, mean_acc, mean_loss))
        file_writer.add_summary(summ, epoch)

    try:
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            file_writer.add_graph(sess.graph)
            try:
                for epoch in range(num_epochs):
                    train_accuracies = []
                    train_losses = []
                    for step in range(train_data.batch_count):
                        batch_images, batch_labels = train_data.next_batch()
                        if step % 10 == 0:
                            train_accuracy, loss, summary = model.train_step(
                                sess, batch_images, batch_labels, run_summary=True)
                            file_writer.add_summary(summary, global_step)
                            print(
                                'Epoch %d, step %d, global step %d, training accuracy: %f, training loss %f'
                                % (epoch, step, global_step, train_accuracy, loss))
                        else:
                            train_accuracy, loss = model.train_step(
                                sess, batch_images, batch_labels, run_summary=False)
                        train_accuracies.append(train_accuracy)
                        train_losses.append(loss)
                        global_step += 1
                    _log_epoch_mean(sess, train_mean, 'Training', epoch,
                                    train_accuracies, train_losses)

                    val_accuracies = []
                    val_losses = []
                    for _step in range(val_data.batch_count):
                        batch_images, batch_labels = val_data.next_batch()
                        val_accuracy, loss, _summary = model.evaluate(
                            sess, batch_images, batch_labels)
                        val_accuracies.append(val_accuracy)
                        val_losses.append(loss)
                    _log_epoch_mean(sess, val_mean, 'Validation', epoch,
                                    val_accuracies, val_losses)

                    saver.save(sess, save_path=save_path, global_step=epoch)
            except KeyboardInterrupt:
                print('Training cancelled intentionally.')
            print('Stop training at %d steps' % global_step)
    finally:
        # Always flush/close the event file, even on unexpected errors.
        file_writer.close()