# Standard imports assumed by these examples.
import os

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# Project-specific helpers and constants (OF, resnet, losses, trainning,
# evaluation, MODEL, MODEL_PATH, IMAGE_SIZE, NUM_CHANNELS, N_CLASSES,
# BATCH_SIZE, NUM_TEST, LEARNING_RATE_BASE, STEP, LOG_NUM, SAVE_NUM) are
# defined elsewhere in the project and are not shown here.


def test():
    model = tf.train.import_meta_graph(MODEL_PATH + ".meta")
    graph = tf.get_default_graph()
    inputs = graph.get_operation_by_name('x-input').outputs[0]
    labels = graph.get_operation_by_name('y-input').outputs[0]
    is_train = graph.get_operation_by_name('is_train').outputs[0]
    # tf.get_collection() returns a list; only the first element is needed here.
    pred = tf.get_collection('pred_network')[0]

    with tf.Session(graph=graph) as sess:
        model.restore(sess, MODEL_PATH)
        # x, y = ds.get_batch_data(TEST_DIR, sess, False, BATCH_SIZE)
        get_flow = OF(sess, "Test", [IMAGE_SIZE, IMAGE_SIZE, 3], BATCH_SIZE)
        next_batch = get_flow.get_batch_data()
        # Collect predictions and ground-truth labels over the test set
        pred_op = tf.argmax(pred, 1)
        test_pred_acc = []
        test_label_acc = []
        for i in tqdm(range(NUM_TEST // BATCH_SIZE), "Testing"):
            test_x, test_y = sess.run(next_batch)
            test_label_acc.extend(np.reshape(np.argmax(test_y, 1), [-1]))
            # Run the restored network on this batch in inference mode
            pred_y = sess.run(pred_op,
                              feed_dict={
                                  inputs: test_x,
                                  labels: test_y,
                                  is_train: False
                              })
            test_pred_acc.extend(pred_y)

        test_pred_acc = tf.cast(test_pred_acc, dtype=tf.int32)
        test_label_acc = tf.cast(test_label_acc, dtype=tf.int32)
        pred_acc = evaluation(test_pred_acc, test_label_acc)
        acc = sess.run(pred_acc)
        print("accuracy : ", acc)
Example #2
def train():
    x = tf.placeholder(tf.float32,
                       [None, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS],
                       name='x-input')
    # Labels
    y_ = tf.placeholder(tf.int64, [None, 1], name='y-input')
    # Flag indicating whether the network is in training mode
    is_train = tf.placeholder(tf.bool, name="is_train")
    # Forward pass to get the network output
    y = resnet.inference(x, resnet.ResNet_demo['layer_101'], N_CLASSES,
                         is_train)

    loss = losses(y, y_)
    acc = evaluation(y, y_)
    train_step = trainning(loss, LEARNING_RATE_BASE, BATCH_SIZE, 1000)
    with tf.control_dependencies([train_step]):
        train_op = tf.no_op(name='train')
    # Persistence: expose the prediction tensor in a collection and create a Saver.
    tf.add_to_collection('pred_network', y)
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=50)
    with tf.Session() as sess:
        # Initialize all network variables
        tf.global_variables_initializer().run()
        with tf.device("/cpu:0"):
            get_flow = OF(sess, "Train",
                          [IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS], BATCH_SIZE)
            # Get the training batch tensor
            next_batch = get_flow.get_batch_data()
        if MODEL is not None:
            # Restore weights from an existing checkpoint
            saver.restore(sess, MODEL)
        tf.summary.image("R.G.B", tf.expand_dims(next_batch[0][0], 0))
        merged = tf.summary.merge_all()
        log_summary = tf.summary.FileWriter("log_files", sess.graph)
        # Training loop
        for i in range(STEP):
            images, labels = sess.run(next_batch)
            _, loss_value, acc_value, merged_value = sess.run(
                [train_op, loss, acc, merged],
                feed_dict={
                    x: images,
                    y_: labels,
                    is_train: True
                })
            log_summary.add_summary(merged_value, i)
            if i % LOG_NUM == 0:
                print(
                    "After %d training step(s), loss on training batch is %g."
                    % (i, loss_value), "acc : ", acc_value)
                # Save a checkpoint
                if loss_value is not None and i % SAVE_NUM == 0:
                    save_path = os.path.join(MODEL_PATH, 'mod.ckpt')
                    saver.save(sess, save_path, global_step=i)
        log_summary.close()
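losses() and trainning() are likewise project helpers that these examples assume. A minimal sketch, assuming a sparse softmax cross-entropy loss and a gradient-descent step with an exponentially decayed learning rate (the (LEARNING_RATE_BASE, BATCH_SIZE, 1000) call suggests a decay schedule), could look like this; the project's real definitions may differ:

def losses(logits, labels):
    # labels are assumed to be int64 class indices of shape [None, 1]
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=tf.reshape(labels, [-1]))
    return tf.reduce_mean(cross_entropy)


def trainning(loss, learning_rate_base, batch_size, decay_steps):
    # batch_size is accepted only to match the call site; this sketch ignores it.
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(
        learning_rate_base, global_step, decay_steps, 0.96, staircase=True)
    return tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)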