-
Notifications
You must be signed in to change notification settings - Fork 1
/
8_run.py
31 lines (21 loc) · 966 Bytes
/
8_run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
from _8_mnist_sinkprop import DigitCNN
import ml_datasets
from tqdm import tqdm
def _permute_columns(X, perm_matrix):
    """Return ``np.dot(X, perm_matrix)`` for a 0/1 permutation matrix, computed by indexing.

    Output column j of ``X @ perm_matrix`` is the input column i with
    ``perm_matrix[i, j] == 1``, i.e. ``perm_matrix.argmax(axis=0)[j]``.
    Gathering those columns costs O(N*D) instead of the O(N*D^2) dense
    product. The result is cast to the dtype ``np.dot`` would have returned,
    so downstream code sees the same dtype as before.

    Args:
        X: array whose last axis has length ``perm_matrix.shape[0]``.
        perm_matrix: square permutation matrix (exactly one 1 per row/column).

    Returns:
        Array equal to ``np.dot(X, perm_matrix)``.
    """
    X = np.asarray(X)
    col_order = perm_matrix.argmax(axis=0)
    return X[..., col_order].astype(np.result_type(X, perm_matrix), copy=False)


trX, trY, teX, teY = ml_datasets.mnist(onehot=True)

##########################################
# Permute the columns of the train and test matrices
##########################################
# Random permutation matrix: the rows of the identity in shuffled order.
# Drawn from NumPy's global RNG with no seed, so it differs on every run.
# Its size follows the feature dimension of the data (784 = 28*28 pixels
# for flattened MNIST).
Q = np.random.permutation(np.eye(np.shape(trX)[-1]))
trX = _permute_columns(trX, Q)
teX = _permute_columns(teX, Q)
# NOTE: disabled reshape to an image layout of (N, 1, 28, 28):
# trX = trX.reshape((-1, 1, 28, 28))
# teX = teX.reshape((-1, 1, 28, 28))
cnn = DigitCNN()
def train(iters=1, eta=.1, perm_eta=.1, batch_size=128):
    """Train the module-level ``cnn`` on ``trX``/``trY`` with minibatch updates.

    After each pass over the training data, print the classification error
    on the full test set and on the first 10,000 training examples.

    Args:
        iters: number of passes (epochs) over the training set.
        eta: forwarded to ``cnn.train`` as ``eta`` (presumably the learning
            rate -- TODO confirm in DigitCNN).
        perm_eta: forwarded to ``cnn.train`` as ``perm_eta`` (presumably the
            step size for the learned permutation -- TODO confirm).
        batch_size: examples per ``cnn.train`` call. Only full batches are
            used; a trailing partial batch is skipped, because the model may
            expect a fixed batch size.

    Reads the globals ``cnn``, ``trX``, ``trY``, ``teX`` and ``teY``. Nothing
    in this script calls ``train``; invoke it explicitly.
    """
    n_train = len(trX)
    # Start offsets of every *full* batch. The old
    # zip(range(0, n, 128), range(128, n, 128)) also dropped the last full
    # batch whenever n was an exact multiple of 128.
    batch_starts = range(0, n_train - batch_size + 1, batch_size)
    for _ in range(iters):  # range, not the Python-2-only xrange
        for start in tqdm(batch_starts):
            end = start + batch_size
            cnn.train(trX[start:end], trY[start:end], eta=eta, perm_eta=perm_eta)
        # Error = 1 - accuracy, with labels recovered from one-hot targets.
        test_error = 1 - np.mean(np.argmax(teY, axis=1) == cnn.predict(teX))
        train_error = 1 - np.mean(
            np.argmax(trY[:10000], axis=1) == cnn.predict(trX[:10000])
        )
        print("test error: %s, train error %s" % (test_error, train_error))