Example #1
0
File: train.py — Project: zfxu/HCCR
def inference():
    """Run single-image inference.

    Builds the model, restores the latest checkpoint from
    ``FLAGS.model_path``, preprocesses ``FLAGS.test_pic`` and prints the
    top-k predicted characters with their probabilities.
    """
    # Build the model graph (top-k of 5 over the character set).
    model = cnn_model.model(5, FLAGS.charset_size)

    # Initialize variables and restore the most recent checkpoint.
    sess = tf.Session()
    try:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(FLAGS.model_path)
        saver.restore(sess, ckpt)

        # Mapping from label index to character.
        # NOTE(review): pickle.load is unsafe on untrusted files;
        # acceptable only because char_dict2 is produced locally.
        with open("char_dict2", "rb") as f:
            char_dict = pickle.load(f)

        # Read grayscale, invert, then apply the project's preprocessing.
        img = cv2.imread(FLAGS.test_pic, 0)
        img = 255 - img
        img = pre_process(img)
        cv2.imshow("", img)
        cv2.waitKey(0)
        img = np.reshape(img, (1, FLAGS.image_size, FLAGS.image_size, 1))

        feed_dict = {model['images']: img, model['training_or_not']: False}

        out_labels, probs = sess.run(
            [model['index_top_k'], model['val_top_k']],
            feed_dict=feed_dict)
        print('-' * 30)
        print('Top k result')
        for i in range(len(out_labels[0])):
            # Fixed typo in the output string: "Probality" -> "Probability".
            print('%d: Predicted val %s, Probability %3f' %
                  (i + 1, char_dict[out_labels[0][i]], probs[0][i]))
    finally:
        # Always release the session, even if restore/inference fails.
        sess.close()
def variableFromId(Id):
    """Build a pooled feature Variable for the images belonging to *Id*.

    Samples every 5th image (at most 5 of them), runs each through the
    global ``model``, stacks the feature vectors and max-pools them.

    Returns:
        (tensor, bool): the pooled features, and True when *Id* had no
        images (a pooled all-zero 1x1x4096 tensor is returned instead).
    """
    image_list = imageListFromId(Id)
    if not image_list:
        # No images for this id: fall back to a zero feature vector.
        zeros = torch.autograd.Variable(torch.zeros(1, 1, 4096))
        pooled = maxPool(zeros)
        print(Id)
        return (pooled.cuda(), True) if use_cuda else (pooled, True)

    # Every 5th image, capped at 5 images total (indices 0,5,10,15,20) —
    # equivalent to the original n%5 / count>4 bookkeeping.
    features_vectors = []
    for img in image_list[::5][:5]:
        img_tensor = preprocess(img)
        img_tensor.unsqueeze_(0)
        img_variable = Variable(img_tensor, requires_grad=False)
        if use_cuda:
            # Fixed: `async` is a reserved keyword since Python 3.7 and
            # `cuda(async=True)` is a SyntaxError; PyTorch renamed the
            # argument to `non_blocking`.
            img_variable = img_variable.cuda(non_blocking=True)
        features_vectors.append(model(img_variable))

    pooled = maxPool(torch.stack(features_vectors))
    return (pooled.cuda(), False) if use_cuda else (pooled, False)
Example #3
0
def extract_weight():
    """Restore the latest checkpoint and pickle the first 16 variables
    (name -> numpy value) into ``test.bin``."""
    model = cnn_model.model(5, 3755)

    # Initialization: fresh session, then restore the newest checkpoint.
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    checkpoint = tf.train.latest_checkpoint(MODEL_PATH)
    saver.restore(sess, checkpoint)

    variables = tf.all_variables()
    weights = {}
    # First 16 variables only — the convolutional weight tensors.
    for idx in range(16):
        var = variables[idx]
        print(var.name)
        weights[var.name] = sess.run(var)

    with open("test.bin", 'wb') as f:
        pickle.dump(weights, f)

    sess.close()
Example #4
0
import numpy as np
np.random.seed(10)
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.reset_default_graph()


# Load MNIST and reshape / one-hot encode via the project's helper module.
(train_images, train_labels), (test_images, test_labels) = \
    tf.keras.datasets.mnist.load_data()
train_images, train_labels = helper.reshape_and_onehot(train_images, train_labels)
test_images, test_labels = helper.reshape_and_onehot(test_images, test_labels)

with tf.variable_scope('Input'):
    # Placeholders for image batches and one-hot labels.
    X = tf.placeholder(
        name='x',
        shape=[None, train_images.shape[1], train_images.shape[2],
               train_images.shape[3]],
        dtype=tf.float32)
    Y = tf.placeholder(name='y', shape=[None, config.num_classes],
                       dtype=tf.float32)
    output_logits = cnn_model.model(X)

with tf.variable_scope('Train'):

    with tf.variable_scope('Loss'):
        # Softmax cross-entropy averaged over the batch.
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=Y,
                                                    logits=output_logits),
            name='loss')
    tf.summary.scalar('loss_summary', loss)

    with tf.variable_scope('Optimizer'):
        optimizer = tf.train.AdamOptimizer(
            learning_rate=config.learning_rate, name='AdamOpt').minimize(loss)

    with tf.variable_scope('Accuracy'):
        correct_predictions = tf.equal(tf.argmax(output_logits, 1),
                                       tf.argmax(Y, 1),
                                       name='correct_prediction')
        accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32),
                                  name='accuracy')
    tf.summary.scalar('accuracy_summary', accuracy)
Example #5
0
# Problem setup: one training and one test sample on a 1-D grid of 6 points,
# generated by the project's mg_system_data module.
gridsize = 6
train_b_sets, train_solution_sets = mg_system_data.gen_data(gridsize, 1)
# Reshape to (batch=1, gridsize, channels=1) — presumably the layout
# cnn_model.model expects; TODO confirm against cnn_model.
train_b_sets = train_b_sets.reshape((1, gridsize, 1))
train_solution_sets = train_solution_sets.reshape((1, gridsize, 1))
test_b_sets, test_solution_sets = mg_system_data.gen_data(gridsize, 1)
#test_b_sets = test_b_sets.reshape((1, 1, gridsize))
#test_solution_sets = test_solution_sets.reshape((1, 1, gridsize))

print(train_b_sets.shape, test_b_sets.shape)
print('training b set = ', train_b_sets)
print('training u set = ', train_solution_sets)
print('A = ', A.shape)  # A is defined elsewhere in the file — not visible here
#print (np.matmul(train_solution_sets, A))
print(train_solution_sets.shape, test_solution_sets.shape)

# Build the network graph: b is the input placeholder, u_ the target,
# output the model's prediction.
b, u_, output = cnn_model.model(6)
# tf.Print logs the tensor values at run time (deprecated in later TF).
output = tf.Print(output, [output])

# L1 loss between the predicted and true solution.
loss = tf.losses.absolute_difference(output, u_)
#optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-4)
optimizer = tf.train.RMSPropOptimizer(learning_rate=1e-4)
train = optimizer.minimize(loss)

# NOTE(review): exact float equality — accuracy will almost always be 0 for
# real-valued outputs; a tolerance-based comparison would be more meaningful.
correct = tf.equal(output, u_)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

sess = tf.Session()

def network_training(b_vals, actual_u_outputs, epoch_num):
    for e in range(epoch_num):
Example #6
0
File: train.py — Project: zfxu/HCCR
def validation_test():
    """Evaluate the model on the validation set.

    Streams batches from ``FLAGS.val_data_dir``, restores the latest
    checkpoint, records every misclassified sample (a text log plus the
    image dumped to ./wrong_pic/) and prints top-1 / top-k accuracy.
    """
    val_data_reader = DataReader(FLAGS.val_data_dir, 1, 1, False)
    image_batch, label_batch = val_data_reader.input()

    model = cnn_model.model(5, FLAGS.charset_size)

    # The op for initializing the variables.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    # Initialization.
    sess = tf.Session()
    sess.run(init_op)

    # Start input enqueue threads.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Model saver: restore the most recent checkpoint.
    saver = tf.train.Saver()
    ckpt = tf.train.latest_checkpoint(FLAGS.model_path)
    saver.restore(sess, ckpt)

    # Label-index -> character dictionary.
    # NOTE(review): pickle.load is unsafe on untrusted files.
    with open("char_dict2", "rb") as f:
        char_dict = pickle.load(f)

    # Log of the wrong predictions.
    out_file = open('wrong_condition.txt', 'w')

    step = 0
    sum1 = 0   # number of top-1 correct predictions
    sumk = 0   # number of top-k correct predictions
    wrong_num = 1
    try:
        while not coord.should_stop():
            # Image and label data.
            val_images_batch, val_labels_batch = sess.run(
                [image_batch, label_batch])
            feed_dict = {
                model['images']: val_images_batch,
                model['labels']: val_labels_batch,
                model['training_or_not']: False
            }

            out_labels, probs = sess.run(
                [model['index_top_k'], model['val_top_k']],
                feed_dict=feed_dict)

            # Fixed: sum1/sumk were declared but never incremented, so the
            # final accuracy report always printed 0.0000.
            if out_labels[0][0] == val_labels_batch[0]:
                sum1 += 1
            if val_labels_batch[0] in out_labels[0]:
                sumk += 1

            if out_labels[0][0] != val_labels_batch[0]:
                out_file.write('true:' + char_dict[val_labels_batch[0]] + '\n')
                out_file.write('topk')
                for i in range(len(out_labels[0])):
                    out_file.write(char_dict[out_labels[0][i]] +
                                   str(probs[0][i]))
                out_file.write('\n')

                # Undo the [-0.5, 0.5] normalization so the dump is viewable.
                val_images_batch[0] = (val_images_batch[0] + 0.5) * 255
                cv2.imwrite('./wrong_pic/' + str(wrong_num) + '.bmp',
                            val_images_batch[0])
                wrong_num += 1

            step += 1
            if step % 100 == 0:
                print('Step %d' % step)

    except tf.errors.OutOfRangeError:
        print('Done!')
    finally:
        # When done, ask the threads to stop.
        coord.request_stop()

    # Wait for threads to finish.
    coord.join(threads)
    sess.close()

    out_file.close()

    # Fixed: guard against ZeroDivisionError when no batches were read.
    if step > 0:
        print('Average: accuracy = %.4f, top k accuracy = %.4f' %
              (sum1 / step, sumk / step))
Example #7
0
# coding=utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from cnn_model import model


# 1. Build the model graph.
x_data, y_true, y_predict, loss, train_op, accuracy = model()

# 1.1 Load the MNIST data set.
mnist = input_data.read_data_sets("./data", one_hot=True)

# 2. Open a session and run the training loop.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(2000):
        # Fixed: the original drew ONE batch outside the loop and trained
        # on the same 50 samples for all 2000 steps; a fresh batch must be
        # sampled every iteration.
        image_batch, label_batch = mnist.train.next_batch(batch_size=50)
        _, _loss, _accuracy = sess.run(
            [train_op, loss, accuracy],
            feed_dict={x_data: image_batch, y_true: label_batch})
        print("i:", i, " loss:", _loss, " acc:", _accuracy)