/
main.py
89 lines (74 loc) · 2.19 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import mnist
import numpy as np
from conv_layer import ConvLayer
from max_pool_layer import MaxPoolLayer
from softmax_layer import SoftmaxLayer
# Number of leading samples kept from each MNIST split (train and test).
SAMPLE_LIMIT = 1000
# Filter count handed to ConvLayer; reused below so the softmax input size
# cannot drift out of sync with the conv layer.
NUM_FILTERS = 8
# Second argument of SoftmaxLayer; labels are documented below as digits.
NUM_CLASSES = 10
# NOTE(review): 13 is presumably the side length of each pooled feature map
# (28x28 input -> conv -> pool) -- confirm against ConvLayer / MaxPoolLayer.
POOLED_SIDE = 13

# Keep only the first SAMPLE_LIMIT samples of each split.
train_images = mnist.train_images()[:SAMPLE_LIMIT]
train_labels = mnist.train_labels()[:SAMPLE_LIMIT]
test_images = mnist.test_images()[:SAMPLE_LIMIT]
test_labels = mnist.test_labels()[:SAMPLE_LIMIT]

# The three layers, in the order the forward pass applies them.
conv = ConvLayer(NUM_FILTERS)
pool = MaxPoolLayer()
softmax = SoftmaxLayer(POOLED_SIDE * POOLED_SIDE * NUM_FILTERS, NUM_CLASSES)
def check(image):
    '''
    Runs image through conv -> pool -> softmax and returns the index of the
    largest softmax output.
    - image is rescaled as (image / 255) - 0.5 before entering the conv layer
    '''
    scaled = (image / 255) - 0.5
    scores = softmax.forward(pool.forward(conv.forward(scaled)))
    return np.argmax(scores)
def forward(image, label):
    '''
    Runs one forward pass of the CNN and scores the output against label.
    - image is a 2d numpy array
    - label is a digit
    Returns (out, loss, acc): the softmax layer's output, the cross-entropy
    loss for label, and 1 if the largest output sits at index label, else 0.
    '''
    # Shift pixel values from [0, 255] to [-0.5, 0.5] before the conv layer.
    scaled = (image / 255) - 0.5
    out = softmax.forward(pool.forward(conv.forward(scaled)))

    # Cross-entropy loss uses the natural log of the entry indexed by label.
    loss = -np.log(out[label])
    if np.argmax(out) == label:
        acc = 1
    else:
        acc = 0
    return out, loss, acc
def train(im, label, lr=.005):
    '''
    Runs one training step: a forward pass followed by backprop through the
    softmax, pool and conv layers (in that order).
    Returns the cross-entropy loss and accuracy from the forward pass.
    - im is a 2d numpy array
    - label is a digit
    - lr is the learning rate, passed to the softmax and conv backprop calls
    '''
    # Forward pass; the raw softmax output is not needed here.
    _, loss, acc = forward(im, label)

    # Backward pass: each layer receives what the layer after it returned.
    # NOTE(review): presumably the layers update their own weights inside
    # backprop -- confirm in the layer classes.
    grad = softmax.backprop(label, lr)
    grad = pool.backprop(grad)
    conv.backprop(grad, lr)

    return loss, acc
print('MNIST CNN initialized!')

# Train!
# One pass over the training samples; loss and num_correct accumulate over
# the current 100-step window and are reset after each progress report.
loss = 0
num_correct = 0
for i, (im, label) in enumerate(zip(train_images, train_labels)):
    # Train on this sample first so it is counted in the window reported
    # below. Reporting before training left the first window with only 99
    # samples (still divided by 100) and never reported the final sample.
    step_loss, acc = train(im, label)
    loss += step_loss
    num_correct += acc

    # Every 100th step (i = 99, 199, ...) the window holds exactly 100
    # samples, so the correct-count doubles as a percentage.
    if i % 100 == 99:
        print(
            '[Step %d] Past 100 steps: Average Loss %.3f | Accuracy: %d%%' %
            (i + 1, loss / 100, num_correct)
        )
        loss = 0
        num_correct = 0
# Evaluate on the test samples: forward passes only, no backprop.
print('\n--- Testing the CNN ---')
num_tests = len(test_images)
loss = 0
num_correct = 0
for sample, target in zip(test_images, test_labels):
    # forward returns (out, loss, acc); only the last two are tallied.
    sample_loss, hit = forward(sample, target)[1:]
    loss = loss + sample_loss
    num_correct = num_correct + hit

# Both figures are means over the evaluated samples.
print('Test Loss:', loss / num_tests)
print('Test Accuracy:', num_correct / num_tests)