def run_test(model_arch_module, dataset_dir, model_path, result_dir, top_k=1):
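    """Evaluate a restored checkpoint on the test split of dataset_dir.

    Prints per-batch and overall top-k accuracy, then pickles the ground-truth
    indices and predicted probabilities into result_dir.
    """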
    import os
    import pickle
    import numpy as np
    import tensorflow as tf
    import data_reader
    import data_config as cfg

    os.makedirs(result_dir, exist_ok=True)

    model = model_arch_module.build_model_arch()
    _, _, test_data = data_reader.read_data_set_dir(dataset_dir, cfg.one_hot,
                                                    64)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, model_path)

        # Accumulators for ground-truth indices, predicted probabilities and
        # the running count of correct predictions over the whole test set.
        overall_truth_indices = []
        overall_pred_probs = []
        overall_correct_classification = 0.
        for step in range(test_data.batch_count):
            images, one_hot = test_data.next_batch()
            truth_indexes = np.argmax(one_hot, 1)

            pred = model.predict(sess, images)[0]
            pred_indexes = np.argmax(pred, 1)

            if top_k == 1:
                correct_classification = (
                    truth_indexes == pred_indexes).astype(np.float32)
            else:
                # Compute top-k membership in NumPy; calling tf.nn.in_top_k
                # here would add new ops to the graph on every batch.
                top_k_indexes = np.argsort(pred, axis=1)[:, -top_k:]
                correct_classification = np.any(
                    top_k_indexes == truth_indexes[:, None],
                    axis=1).astype(np.float32)

            overall_truth_indices.extend(truth_indexes)
            overall_pred_probs.extend(pred)
            overall_correct_classification += correct_classification.sum()
            print("Step %d, accuracy %f" %
                  (step, correct_classification.mean()))

        overall_acc = overall_correct_classification / float(
            test_data.data_set_size)
        print("Overall accuracy: %f" % overall_acc)

        with open(os.path.join(result_dir, 'truth_indices.pkl'), 'wb') as f:
            pickle.dump(overall_truth_indices, f)
        with open(os.path.join(result_dir, 'pred_probs.pkl'), 'wb') as f:
            pickle.dump(overall_pred_probs, f)


def run_test_visual(model_arch_module, dataset_dir, model_path):
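    """Run the restored model over the test split and display each batch of
    images together with the predicted and ground-truth labels.
    """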
    import numpy as np
    import tensorflow as tf
    import data_reader
    import data_visualizer as dv
    import data_config as cfg

    model = model_arch_module.build_model_arch()
    _, _, test_data = data_reader.read_data_set_dir(dataset_dir, cfg.one_hot,
                                                    24)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, model_path)

        labels = cfg.one_hot_labels

        for step in range(test_data.batch_count):
            images, one_hot = test_data.next_batch()
            truth_indexes = np.argmax(one_hot, 1)

            pred = model.predict(sess, images)[0]
            pred_indexes = np.argmax(pred, 1)
            pred_labels = [labels[i] for i in pred_indexes]

            dv.show_images_with_truth(images, pred_labels, truth_indexes,
                                      pred_indexes)
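

# get_mean_op is called by run_trainer below but is not shown in this listing.
# The following is a minimal sketch of what it plausibly looks like, assuming
# float32 placeholders for per-batch metrics and tf.summary scalars for
# TensorBoard; it is an illustration, not the original implementation.
def get_mean_op():
    import tensorflow as tf

    # Placeholders fed with the per-batch accuracy/loss lists of one epoch.
    accuracies_input = tf.placeholder(tf.float32, [None])
    losses_input = tf.placeholder(tf.float32, [None])

    mean_acc = tf.reduce_mean(accuracies_input)
    mean_loss = tf.reduce_mean(losses_input)

    # Distinct tags keep the training and validation curves separate.
    train_summary = tf.summary.merge([
        tf.summary.scalar('epoch_train_accuracy', mean_acc),
        tf.summary.scalar('epoch_train_loss', mean_loss),
    ])
    val_summary = tf.summary.merge([
        tf.summary.scalar('epoch_val_accuracy', mean_acc),
        tf.summary.scalar('epoch_val_loss', mean_loss),
    ])

    # Each fetch list yields (mean accuracy, mean loss, summary), matching
    # the sess.run unpacking in run_trainer.
    train_mean = [mean_acc, mean_loss, train_summary]
    val_mean = [mean_acc, mean_loss, val_summary]
    return accuracies_input, losses_input, train_mean, val_mean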


def run_trainer(model_arch_module, num_epochs, batch_size, dataset_path,
                model_name, run_name):
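    """Train for num_epochs epochs, logging per-step and per-epoch metrics to
    TensorBoard under logs/<run_name>/ and checkpointing after every epoch.

    A KeyboardInterrupt (Ctrl+C) stops training early and cleanly.
    """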

    import tensorflow as tf
    import data_reader
    import data_config as cfg

    model = model_arch_module.build_model_arch()
    train_data, val_data, _ = data_reader.read_data_set_dir(
        dataset_path, cfg.one_hot, batch_size)

    accuracies_input, losses_input, train_mean, val_mean = get_mean_op()

    global_step = 0
    file_writer = tf.summary.FileWriter('logs/%s/' % run_name)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        file_writer.add_graph(sess.graph)

        try:
            for epoch in range(num_epochs):

                train_accuracies = []
                train_losses = []
                for step in range(train_data.batch_count):
                    batch_images, batch_labels = train_data.next_batch()

                    # Every 10th step, also fetch a summary for TensorBoard.
                    if step % 10 == 0:
                        train_accuracy, loss, summary = model.train_step(
                            sess, batch_images, batch_labels, run_summary=True)
                        file_writer.add_summary(summary, global_step)

                        print(
                            'Epoch %d, step %d, global step %d, training accuracy: %f, training loss %f'
                            % (epoch, step, global_step, train_accuracy, loss))
                    else:
                        train_accuracy, loss = model.train_step(
                            sess,
                            batch_images,
                            batch_labels,
                            run_summary=False)

                    train_accuracies.append(train_accuracy)
                    train_losses.append(loss)
                    global_step += 1

                mean_acc, mean_loss, summ = sess.run(train_mean,
                                                     feed_dict={
                                                         accuracies_input:
                                                         train_accuracies,
                                                         losses_input:
                                                         train_losses
                                                     })
                print('Epoch %d: Training accuracy: %f, loss %f' %
                      (epoch, mean_acc, mean_loss))
                file_writer.add_summary(summ, epoch)

                val_accuracies = []
                val_losses = []
                for step in range(val_data.batch_count):
                    batch_images, batch_labels = val_data.next_batch()

                    val_accuracy, loss, summary = model.evaluate(
                        sess, batch_images, batch_labels)
                    val_accuracies.append(val_accuracy)
                    val_losses.append(loss)

                mean_acc, mean_loss, summ = sess.run(val_mean,
                                                     feed_dict={
                                                         accuracies_input:
                                                         val_accuracies,
                                                         losses_input:
                                                         val_losses
                                                     })
                print('Epoch %d: Validation accuracy: %f, loss %f' %
                      (epoch, mean_acc, mean_loss))
                file_writer.add_summary(summ, epoch)

                # Checkpoint once per epoch; global_step=epoch suffixes the
                # checkpoint files with the epoch number.
                saver.save(sess,
                           save_path='./model/%s' % model_name,
                           global_step=epoch)
        except KeyboardInterrupt:
            print('Training cancelled intentionally.')

        print('Stop training at %d steps' % global_step)
        file_writer.close()
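

# Example invocation, a minimal sketch: `my_model_arch`, the dataset path and
# the run/model names below are hypothetical placeholders, not names defined
# in this listing.
if __name__ == '__main__':
    import my_model_arch

    run_trainer(my_model_arch, num_epochs=20, batch_size=64,
                dataset_path='data/signs', model_name='signs_model',
                run_name='signs_run_01')
    # run_test would normally be a separate invocation, since each function
    # builds its own copy of the graph, e.g.:
    # run_test(my_model_arch, 'data/signs', './model/signs_model-19',
    #          'results/signs', top_k=1)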