def train():
    """Train ZFNet on the CAD training TFRecords, validating and saving periodically.

    Behavior is selected by ``CONFIG.mode``:

    * ``'train0'`` -- fresh start: run the global and local variable initializers.
    * ``'train1'`` -- resume: initialize local variables, then restore weights
      from ``CONFIG.model_save_path`` via ``ZFNet.restore``.

    The loop runs ``CONFIG.epochs`` iterations. Each iteration fetches one
    training batch and performs one ``myzfnet.train`` call. Every
    ``CONFIG.per_save`` iterations the model is evaluated on a single,
    unshuffled batch of 100 validation examples, fetched once before the loop.
    The model is then saved to ``CONFIG.model_save_path``.

    Returns:
        A ``(loss, acc)`` tuple of lists with the per-iteration training loss
        and accuracy values returned by ``myzfnet.train``.

    Raises:
        ValueError: if ``CONFIG.mode`` is not ``'train0'`` or ``'train1'``.
    """
    # Fail fast on an unknown mode. Otherwise the session would run with
    # variables that were neither initialized nor restored, and fail later
    # with a much less helpful TensorFlow error.
    if CONFIG.mode not in ('train0', 'train1'):
        raise ValueError(
            "CONFIG.mode must be 'train0' or 'train1', got %r" % (CONFIG.mode,))

    train_file = ['data/train_cad.tfrecord']
    valid_file = ['data/val_cad.tfrecord']
    # The 4th argument (10 * batch_size) looks like a shuffle-buffer /
    # queue-capacity size -- TODO confirm against batched_data.
    train_batch = batched_data(train_file, single_example_parser,
                               CONFIG.batch_size, 10 * CONFIG.batch_size)
    # Unshuffled batches of 100 validation examples.
    valid_batch = batched_data(valid_file, single_example_parser, 100,
                               shuffle=False)

    with tf.Session() as sess:
        myzfnet = ZFNet(CONFIG)
        if CONFIG.mode == 'train0':
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
        else:  # 'train1': resume from a checkpoint.
            # The fresh-start path initializes local variables, so do the same
            # here. It runs BEFORE restore so it cannot overwrite restored
            # state, and it is a no-op if the graph has no local variables.
            sess.run(tf.local_variables_initializer())
            myzfnet.restore(sess, CONFIG.model_save_path)

        # Fetched once; the same validation batch is reused at every checkpoint.
        X_val = sess.run(valid_batch)

        loss = []
        acc = []
        for epoch in range(1, CONFIG.epochs + 1):
            # NOTE: each iteration trains on ONE batch (a single sess.run of
            # the training pipeline), so CONFIG.epochs counts steps, not
            # full passes over the data.
            X = sess.run(train_batch)
            loss_, acc_, prediction_ = myzfnet.train(sess, X[0], X[1])
            loss.append(loss_)
            acc.append(acc_)
            # Debug output: model predictions vs. X[1] (presumably the labels).
            print(prediction_)
            print(X[1])
            print('>> %d/%d | loss: %f acc: %.2f%%'
                  % (epoch, CONFIG.epochs, loss_, 100.0 * acc_))
            if epoch % CONFIG.per_save == 0:
                acc_val = myzfnet.eval(sess, X_val[0], X_val[1])
                print(' acc_val: %.2f%%\n' % (100.0 * acc_val))
                myzfnet.save(sess, CONFIG.model_save_path)

    return loss, acc
def predict():
    """Restore the saved ZFNet and print its predictions on one validation batch.

    Builds an unshuffled pipeline of 100-example batches from the CAD
    validation TFRecords and restores weights from ``CONFIG.model_save_path``.
    It then fetches one batch and prints ``ZFNet.predict`` applied to that
    batch's first component (presumably the input features), followed by a
    separator line.
    """
    val_pipeline = batched_data(['data/val_cad.tfrecord'], single_example_parser,
                                100, shuffle=False)
    with tf.Session() as session:
        net = ZFNet(CONFIG)
        net.restore(session, CONFIG.model_save_path)
        val_inputs = session.run(val_pipeline)[0]
        predictions = net.predict(session, val_inputs)
        print(predictions)
        print('----------------------------------------------------------------')