Esempio n. 1
0
def _argmax(scores):
    """Return the index of the largest value in *scores* (first on ties); 0 if empty."""
    best_index = 0
    for index, score in enumerate(scores):
        # Compare against the current best score, not a fixed 0.0 floor, so
        # all-negative scores (e.g. raw logits) still yield the true maximum.
        if score > scores[best_index]:
            best_index = index
    return best_index


def test(name, words, files):
    """Restore a saved model and print its predicted class for every test line.

    Args:
        name: Checkpoint path handed to ``tf.train.Saver.restore``.
        words: Vocabulary; its length sizes the network input built by
            ``th.setup`` and it is also passed to ``load.load_test_data``.
        files: Class labels; ``len(files)`` is the number of output classes
            and ``files[i]`` is printed as the label of class ``i``.

    Prints one line per test example with the raw text, the predicted label
    and the raw network output. Returns ``None``.
    """
    class_count = len(files)
    # th.setup also returns the target placeholder, which inference does not need.
    x, y, _ = th.setup(len(words), class_count)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver.restore(sess, name)

        print("Model restored")
        lines, inputs = load.load_test_data(words)

        for line, inp in zip(lines, inputs):
            result = sess.run(y, feed_dict={x: inp})
            # NOTE(review): only row 0 is inspected, which presumably means each
            # `inp` holds a single example — confirm against load.load_test_data.
            class_index = _argmax(result[0])
            print("Test data " + line + " : " + files[class_index] + " : " + str(result) + " : ")
Esempio n. 2
0
import sys

print("Init..")
model_name = "model.ckpt"
epochs = 2
# Command line "e <n>" overrides the default number of training epochs.
if "e" in sys.argv:
    epochs_pos = sys.argv.index("e") + 1
    if epochs_pos >= len(sys.argv):
        sys.exit("Missing value after 'e': expected the number of epochs")
    try:
        epochs = int(sys.argv[epochs_pos])
    except ValueError:
        sys.exit("Invalid epoch count after 'e': " + sys.argv[epochs_pos])
# Class labels; their order defines the output class indices.
files = ["mute", "volume", "channel"]
print("Files: " + ", ".join(files))

print("Loading data..")
inputs, outputs, words = load.load_data(files)

# Command line "t" selects training; otherwise the saved model is tested.
if "t" in sys.argv:
    print("Setup train..")
    sess = tf.InteractiveSession()
    try:
        x, y, y_ = th.setup(len(words), len(files))
        train_step, writer, merged, accuracy = th.trainSetup(y, y_, sess)

        print("Train..")
        th.train(inputs, outputs, x, y_, train_step, sess, epochs, writer, merged, accuracy)

        print("Save..")
        th.save(sess, model_name)
    finally:
        # Release the session's resources even if training or saving fails.
        sess.close()
else:
    print("Test..")
    test.test(model_name, words, files)