Esempio n. 1
0
	def test_mnist_28by28(self):
		"""Evaluate a pre-trained ANN on the LeCun MNIST (28x28) test split.

		Reads whitespace-separated integer pixel rows from
		``data/mnist_test_data.txt`` and one integer label per line from
		``data/mnist_test_label.txt``, min-max scales the pixels into [0, 1],
		restores a network from ``28_200000.pickle`` and prints a confusion
		matrix and a classification report of its arg-max predictions.
		Nothing is asserted; the method only reports results.

		Raises:
			OSError: if one of the data files cannot be opened.
			ValueError: if no images are found or the number of images and
				labels differ.
		"""
		import numpy as np
		from sklearn.metrics import confusion_matrix, classification_report
		from ann import ANN

		# Load the LeCun MNIST test split. Blank lines (e.g. a trailing newline
		# at the end of a file) are skipped so they neither add empty image
		# rows nor make int('') fail.
		with open('data/mnist_test_data.txt', 'r') as fd:
			X = [[int(pixel) for pixel in line.split()] for line in fd if line.strip()]
		with open('data/mnist_test_label.txt', 'r') as fl:
			y = [int(line) for line in fl if line.strip()]

		if not X:
			raise ValueError("no images found in data/mnist_test_data.txt")
		if len(X) != len(y):
			raise ValueError(
				"found {0} images but {1} labels".format(len(X), len(y)))

		# np.float was only an alias of the builtin float and was removed in
		# NumPy 1.24, so spell the dtype explicitly. Labels stay integral so
		# they compare cleanly with the integer arg-max predictions.
		X = np.array(X, dtype=np.float64)
		y_test = np.array(y, dtype=int)

		# Min-max normalize the input into [0, 1]. When every pixel has the
		# same value the range is 0; skip the division instead of producing NaN.
		X -= X.min()
		x_range = X.max()
		if x_range > 0:
			X /= x_range
		X_test = X

		# The [1, 1] network is only a placeholder to call deserialize() on;
		# its return value is the network used for prediction.
		nn = ANN([1, 1])
		nn = nn.deserialize('28_200000.pickle')

		# Predict one sample at a time and keep the index of the largest output.
		predictions = [np.argmax(nn.predict(sample)) for sample in X_test]

		# compute a confusion matrix
		print("confusion matrix")
		print(confusion_matrix(y_test, predictions))

		# show a classification report
		print("classification report")
		print(classification_report(y_test, predictions))
# Drop the module-level references to the full dataset so the arrays can be
# garbage-collected. This only frees memory if nothing else (e.g. a slice or
# view derived from them) still refers to the same data.
X = y = None


def step_cb(nn, step):
	"""Training callback: print a heartbeat and checkpoint the network.

	Prints "ping", then calls ``nn.serialize(nn, name)`` where ``name`` is
	``str(step)`` followed by ".pickle".
	"""
	print("ping")
	checkpoint_name = "{0!s}.pickle".format(step)
	nn.serialize(nn, checkpoint_name)

# Load a previously trained ANN from disk if the checkpoint exists,
# otherwise build a new network and train it.
# Module-level imports needed by this script section.
import os
import time

from ann import ANN

# Placeholder instance: deserialize() is called on an instance and its return
# value becomes the network (in the training branch it is simply replaced).
nn = ANN([1, 1])
serialized_name = '28_1000000.pickle'

if os.path.exists(serialized_name):
	# load a saved ANN
	nn = nn.deserialize(serialized_name)
else:
	# Create the ANN with:
	# - an input layer of size 784 (the images are 28x28 gray pixels)
	# - one hidden layer of size 300
	# - an output layer of size 10 (the labels of digits are 0 to 9)
	nn = ANN([784, 300, 10])

	# See how long training takes. perf_counter is monotonic and
	# high-resolution, whereas time.time follows the wall clock and can jump
	# if the system time is adjusted during a long training run.
	startTime = time.perf_counter()

	# Train it; step_cb is passed as the callback, presumably invoked by
	# train2 with (nn, step) to write checkpoints.
	# NOTE(review): X_train_l and labels_train_l are not defined in the
	# visible part of this file — assumed to be prepared earlier in the script.
	nn.train2(30, X_train_l, labels_train_l, 100000, step_cb)

	elapsedTime = time.perf_counter() - startTime
	print("Training took {0} seconds".format(elapsedTime))

	# NOTE(review): nothing here saves the trained network under
	# serialized_name, so every run without that file retrains from scratch —
	# confirm whether a final save is intended.