/
DigitClassifier.py
108 lines (93 loc) · 3.87 KB
/
DigitClassifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# -*- coding:utf-8 -*-
from NeuralNetwork import NeuralNetwork
from InputCanvas import InputCanvas
from PIL import Image, ImageFilter
from numpy import random
import numpy as np
import tkinter
def float_formatter(x):
    """Format a number with exactly four digits after the decimal point."""
    return "%.4f" % x


# Route every floating-point array element NumPy prints through
# float_formatter, so printed arrays show a fixed four decimals.
np.set_printoptions(formatter={'float_kind': float_formatter})
class DigitClassifier(tkinter.Tk):
    """Tk window that recognizes a hand-drawn digit with a neural network.

    The window shows a drawing canvas framed by a black border, a label
    with the most recent prediction, and "Recognize" / "Clear" buttons.
    """

    def __init__(self):
        """Create the network and lay out the widgets."""
        super().__init__()
        # 784 matches a flattened 28x28 image and 10 matches the one-hot
        # target length; 300 is presumably the hidden-layer size (confirm
        # against NeuralNetwork).
        self.nn = NeuralNetwork(784, 300, 10)

        # A 308x308 black canvas behind the 300x300 input canvas placed at
        # (4, 4) leaves a 4-pixel black frame around the drawing area.
        self.background = tkinter.Canvas(self, width=308, height=308)
        self.background.config(background="black")
        self.input_canvas = InputCanvas(self, width=300, height=300)
        self.result_label = tkinter.Label(self, text='')
        self.recog_button = tkinter.Button(
            self, text='Recognize', command=self.recognize)
        self.clear_button = tkinter.Button(
            self, text='Clear', command=self.input_canvas.clear)

        self.background.pack()
        self.input_canvas.place(x=4, y=4)
        self.result_label.pack()
        self.recog_button.pack()
        self.clear_button.pack()

    @staticmethod
    def _one_hot(label):
        """Return a length-10 float vector with 1.0 at index ``label``."""
        target = np.zeros(10)
        target[label] = 1.0
        return target

    @staticmethod
    def _argmax(values):
        """Return the index of the first largest entry of ``values``."""
        return int(np.argmax(values))

    @staticmethod
    def _augment(image):
        """Return a randomly rotated and shifted copy of ``image``.

        ``image`` is assumed to be a 28x28 uint8 array (TODO confirm against
        the Mnist module). The result is a 28x28 array: the input rotated by
        a random angle in [-45, 45) degrees and pasted onto a black canvas
        at a random offset.
        """
        rotated = Image.fromarray(image).rotate(random.uniform(-45.0, 45.0))
        canvas = Image.new('L', (28, 28))
        # numpy's randint excludes the upper bound, so each offset is in
        # [-5, 4]. Integer bounds are used because that is what the API
        # documents.
        offset = (int(random.randint(-5, 5)), int(random.randint(-5, 5)))
        canvas.paste(rotated, offset)
        return np.asarray(canvas)

    def train_nn(self, epochs=100000, edit_image=False):
        """Train the network on MNIST, evaluate it, and save the weights.

        Prints the 10x10 confusion matrix (rows: true digit, columns:
        predicted digit) and the accuracy on the test set, then writes the
        two weight matrices under ``parameters/``.

        Args:
            epochs: Number of training samples to draw, uniformly at random
                with replacement, from the MNIST training set.
            edit_image: If true, randomly rotate and shift every drawn
                sample before it is used for training.
        """
        # Imported here rather than at module level, as in the original;
        # presumably so the GUI can start without loading the dataset.
        import os
        import Mnist

        # Build the training batch.
        labels = Mnist.trainLabels
        images = Mnist.trainImages
        inputs, targets = [], []
        for _ in range(epochs):
            i = random.randint(len(labels))
            image = self._augment(images[i]) if edit_image else images[i]
            # Scale pixel values from 0..255 down to 0..1.
            inputs.append(image.ravel() / 255.0)
            targets.append(self._one_hot(labels[i]))

        print("start training...")
        # n is presumably the learning rate (confirm against NeuralNetwork).
        self.nn.train(np.array(inputs), np.array(targets), n=0.01)

        # Build the test batch from the full test set.
        inputs, targets = [], []
        for image, label in zip(Mnist.testImages, Mnist.testLabels):
            inputs.append(image.ravel() / 255.0)
            targets.append(self._one_hot(label))

        print("start testing...")
        results = self.nn.test(np.array(inputs), np.array(targets))

        # Confusion matrix; its diagonal counts the correct predictions.
        overall = np.zeros((10, 10), dtype=int)
        for result, target in zip(results, targets):
            overall[self._argmax(target), self._argmax(result)] += 1
        correct = int(np.trace(overall))
        print(overall)
        print(float(correct) / len(targets))

        # Save the trained parameters. Create the directory first so a long
        # training run is not lost to a missing folder.
        os.makedirs('parameters', exist_ok=True)
        np.save('parameters/w1_2.npy', self.nn.w1_2)
        np.save('parameters/w2_3.npy', self.nn.w2_3)

    def load_nn_parameters(self):
        """Load the weight matrices previously written by ``train_nn``.

        Raises:
            FileNotFoundError: If the files under ``parameters/`` are missing.
        """
        self.nn.w1_2 = np.load('parameters/w1_2.npy')
        self.nn.w2_3 = np.load('parameters/w2_3.npy')

    def recognize(self):
        """Classify the digit drawn on the canvas and show the result.

        The prediction is written to the result label and also printed,
        together with the raw network output.
        """
        img = self.input_canvas.getImage().filter(ImageFilter.BLUR).convert('L')
        # Image.ANTIALIAS was removed in Pillow 10. LANCZOS is the same
        # filter under its current name; fall back only on very old Pillow.
        resample = getattr(Image, 'LANCZOS', None)
        if resample is None:
            resample = getattr(Image, 'ANTIALIAS')
        img.thumbnail((28, 28), resample)
        # Invert brightness; presumably the canvas is dark-on-light while
        # the training images are light-on-dark (confirm against InputCanvas).
        img = img.point(lambda x: 255 - x)
        pixels = np.asarray(img).ravel()
        result = self.nn.test([pixels / 255.0], np.zeros(10))[0]
        num = self._argmax(result)
        self.result_label.configure(text=str(num))
        print(num, result)
def main():
    """Open the classifier window using previously saved network weights."""
    app = DigitClassifier()
    app.load_nn_parameters()
    # To retrain before the window opens, call
    # app.train_nn(edit_image=False) here.
    app.mainloop()


if __name__ == "__main__":
    main()