예제 #1
0
파일: train.py 프로젝트: hejiangyou/ziti
def train():
    """Train the VGG classifier on batches pulled from the ``message`` queue.

    Reads ``(image, label)`` batches from the module-level ``message`` queue
    and runs one Adam step per batch. A ``None`` item is treated as an
    end-of-stream sentinel; the loop stops after ``producer_num`` of them
    (presumably one per producer). Every 50 training steps the accuracy and
    loss on the current batch are printed and written as summaries to
    ``./log/``. The final weights are saved to ``./models/vgg.ckpt``.

    Relies on ``class_num``, ``producer_num``, ``message``, ``vgg``, ``tf``
    and ``time`` being defined at module level elsewhere in the file.
    """
    # NOTE(review): the original schedule was ``1e-5 if i < 500 else 1e-5`` --
    # both branches are equal, so the rate is effectively constant. Change
    # this if a warm-up/decay schedule was intended.
    learning_rate = 1e-5

    inputs = tf.placeholder(tf.float32, shape=[None, 128, 128, 3])
    # Target scores per sample, presumably one-hot (compared via argmax below).
    outputs = tf.placeholder(tf.float32, shape=[None, class_num])
    tf.summary.image('inputs', inputs, 16)

    lr = tf.placeholder(tf.float32)
    keep_prob = tf.placeholder(tf.float32)

    pred = vgg(inputs, class_num, keep_prob)

    with tf.name_scope('cross_entropy'):
        # Clip probabilities away from 0 so log() cannot return -inf.
        cross_entropy = tf.reduce_mean(-tf.reduce_sum(
            outputs * tf.log(tf.clip_by_value(pred, 1e-5, 1.0)),
            reduction_indices=[1]))
        tf.summary.scalar('cross_entropy', cross_entropy)

    with tf.name_scope('accuracy'):
        correct = tf.equal(tf.argmax(pred, 1), tf.argmax(outputs, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

    with tf.name_scope('optimizer'):
        optimizer = tf.train.AdamOptimizer(lr).minimize(cross_entropy)

    merged = tf.summary.merge_all()

    saver = tf.train.Saver()
    with tf.Session() as sess:
        writer = tf.summary.FileWriter('./log/', sess.graph)
        try:
            sess.run(tf.global_variables_initializer())

            step, stop_count = 0, 0
            st = time.time()
            while stop_count < producer_num:
                msg = message.get()
                if msg is None:
                    # Sentinel from a producer; it is not a training step.
                    stop_count += 1
                    continue

                step += 1
                image, label = msg
                sess.run(optimizer, feed_dict={
                    inputs: image, outputs: label,
                    lr: learning_rate, keep_prob: 0.5})
                if step % 50 == 0:
                    # Evaluate on the same batch with dropout disabled.
                    summary, acc, loss_value = sess.run(
                        [merged, accuracy, cross_entropy],
                        feed_dict={inputs: image, outputs: label,
                                   keep_prob: 1.0})
                    print('iter:{}, acc:{}, loss:{}'.format(
                        step, acc, loss_value))
                    writer.add_summary(summary, step)
            print('run time: ', time.time() - st)
            # TF1 Saver.save raises if the parent directory is missing.
            tf.gfile.MakeDirs('./models')
            saver.save(sess, './models/vgg.ckpt')
        finally:
            # Flush pending summaries to disk and release the event file.
            writer.close()
예제 #2
0
def predict(class_num, path, data):
    """Classify one image with a restored VGG checkpoint, on CPU only.

    Args:
        class_num: number of output classes the checkpoint was trained with.
        path: checkpoint path passed to ``tf.train.Saver.restore``.
        data: a single image as an ``H x W x 3`` array of any size; it is
            resized to 128x128 inside the graph.

    Returns:
        The network output for the image as returned by ``sess.run``
        (including the leading batch axis of size 1).
    """
    # Hide all GPUs from TensorFlow. Presumably this only takes effect if no
    # session has initialised GPU devices earlier in the process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # Build in a private graph so repeated calls do not add another copy of
    # the VGG variables to the default graph. Such a copy would get renamed or
    # conflicting variable names and is likely to make ``restore`` fail.
    graph = tf.Graph()
    with graph.as_default():
        inputs = tf.placeholder(tf.float32, shape=[None, None, 3])
        example = tf.cast(tf.image.resize_images(inputs, [128, 128]), tf.uint8)
        example = tf.image.per_image_standardization(example)
        example = tf.expand_dims(example, 0)  # add batch axis of size 1
        output = vgg(example, class_num, 1.0)
        saver = tf.train.Saver()

    # The ``with`` block closes the session even if restore/run raises.
    with tf.Session(graph=graph) as sess:
        saver.restore(sess, path)
        return sess.run(output, feed_dict={inputs: data})
예제 #3
0
def run():
    """Evaluate the 2-class VGG checkpoint on generated text images.

    Reads the texts from ``test.txt``, renders them with every font in
    ``dataset/fonts`` via ``generator_images`` and classifies each image with
    the checkpoint at ``models/vgg.ckpt``. The expected class of a sample is
    ``index % 2``. The function prints the sample count, error count and
    accuracy, then passes the misclassified samples to ``show_errors``.

    Relies on ``read_text``, ``generator_images``, ``show_errors``, ``vgg``,
    ``tf``, ``np`` and ``os`` defined elsewhere in the file.
    """
    file_name = u'test.txt'
    # Alternative, larger text list:
    # file_name = u'dataset/中国汉字大全.txt'
    texts = read_text(file_name)

    fonts_dir = os.path.join('dataset', 'fonts')
    fonts = [
        os.path.join(os.getcwd(), fonts_dir, path)
        for path in os.listdir(fonts_dir)
    ]

    images_gen = generator_images(texts, fonts)

    inputs = tf.placeholder(tf.float32, shape=[None, None, 3])
    example = tf.cast(tf.image.resize_images(inputs, [128, 128]), tf.uint8)
    example = tf.image.per_image_standardization(example)
    example = tf.expand_dims(example, 0)  # add batch axis of size 1
    outputs = vgg(example, 2, 1.0)

    restorer = tf.train.Saver()
    total = 0
    error = 0
    error_texts = []
    # The ``with`` block closes the session once evaluation is done.
    with tf.Session() as sess:
        restorer.restore(sess, 'models/vgg.ckpt')

        for index, (image, text) in enumerate(images_gen):
            total = index + 1
            pred = np.squeeze(
                sess.run(outputs, feed_dict={inputs: np.asarray(image)}))
            # argmax returns one class even when scores tie. The previous
            # np.where(pred == max) returned an array in that case, which made
            # the comparison below raise "truth value is ambiguous".
            label = int(np.argmax(pred))
            # Expected class alternates with the sample index; presumably
            # generator_images emits samples in that order.
            if index % 2 != label:
                error_texts.append((text, pred.tolist()))
                error += 1

    # Divide by the number of samples, not the last index. Report NaN when
    # the generator produced nothing, because accuracy is undefined then.
    acc = 1 - float(error) / total if total else float('nan')
    print('test num: {}, error num: {}, acc: {}'.format(total, error, acc))
    show_errors(error_texts, fonts)
예제 #4
0
import tensorflow as tf
from nnets.vgg import vgg

import numpy as np
import tkinter
from tkinter.filedialog import askopenfilename
from tkinter import *
from PIL import Image, ImageFont, ImageDraw

# Inference graph for a single image of any size: resize to 128x128, cast to
# uint8, standardize per image, then add a batch axis of size 1.
inputs = tf.placeholder(tf.float32, shape=[None, None, 3])
example = tf.expand_dims(
    tf.image.per_image_standardization(
        tf.cast(tf.image.resize_images(inputs, [128, 128]), tf.uint8)
    ),
    0,
)
# 9 output classes; the third argument is presumably keep_prob (1.0 = no
# dropout at inference).
output = vgg(example, 9, 1.0)

# Restore trained weights into a module-level session.
sess = tf.Session()
tf.train.Saver().restore(sess, 'models/vgg.ckpt')
print("Model restored.")


def selectPath():
    """Open a file dialog and store the chosen path in the ``path`` variable."""
    path.set(askopenfilename())


def end():
    """Destroy the Tk ``root`` window."""
    window = root
    window.destroy()


# Main Tk window and the string variable holding the selected file path
# (set by selectPath()).
root = Tk()
path = StringVar()
# Caption for the path row; the text reads "Target path:".
Label(root, text="目标路径:").grid(row=0, column=0)
예제 #5
0
파일: web.py 프로젝트: hejiangyou/ziti
from flask import Flask, render_template, request, redirect, url_for
from werkzeug.utils import secure_filename
import os
import tensorflow as tf
from nnets.vgg import vgg
import numpy as np
from PIL import Image, ImageFont, ImageDraw

app = Flask(__name__)

# Hide every GPU so TensorFlow runs inference on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Inference graph for a single image of any size: resize to 128x128, cast to
# uint8, standardize per image, then add a batch axis of size 1.
inputs = tf.placeholder(tf.float32, shape=[None, None, 3])
example = tf.expand_dims(
    tf.image.per_image_standardization(
        tf.cast(tf.image.resize_images(inputs, [128, 128]), tf.uint8)
    ),
    0,
)
output = vgg(example, 4, 1.0)

# Load the trained model into a module-level session.
sess = tf.Session()
tf.train.Saver().restore(sess, 'models/vgg.ckpt')


@app.route('/')
def about():
    """Redirect the site root to the upload page."""
    upload_url = url_for('upload')
    return redirect(upload_url)


@app.route('/upload', methods=['POST', 'GET'])
def upload():
    """Handle ``/upload`` for GET and POST requests.

    On POST, reads the uploaded file from the ``file`` form field.

    NOTE(review): this excerpt appears truncated -- ``basepath`` is computed
    but never used and no response is returned in the visible lines. Confirm
    the remaining behavior against the full web.py.
    """
    if request.method == 'POST':
        f = request.files['file']
        basepath = os.path.dirname(__file__)  # directory containing this file
예제 #6
0
def selectPath():
    """Prompt for a file and assign the chosen path to ``path``."""
    chosen_path = askopenfilename()
    path.set(chosen_path)


def end():
    """End the GUI by destroying the Tk root window."""
    root.destroy()


# Inference graph for a single image of any size: resize to 128x128, cast to
# uint8, standardize per image, then add a batch axis of size 1.
inputs = tf.placeholder(tf.float32, shape=[None, None, 3])
example = tf.expand_dims(
    tf.image.per_image_standardization(
        tf.cast(tf.image.resize_images(inputs, [128, 128]), tf.uint8)
    ),
    0,
)
output = vgg(example, 5, 1.0)

# Load the trained model into a module-level session.
sess = tf.Session()
tf.train.Saver().restore(sess, 'models/vgg.ckpt')

root = Tk()
# Window geometry: 570x185 pixels, placed at screen offset (+500, +400).
root.geometry('570x185+500+400')
path = StringVar()  # file path chosen via selectPath()
image_frame = Frame(root)
# Image/label handles start empty; presumably populated by create_image_label().
image_file = None
im = None
image_label = None


def create_image_label():
    '''创建图片lable和字体识别lable'''
    data = Image.open(path.get())