# example.py - usage example from RBMLibrary/DBN-Library (this copy is a fork).
from data import MNIST
from dbn import DBN
import numpy as np
# Trains a Deep Belief Network (DBN) on MNIST and reports its classification
# error on the held-out test set.
#
# NOTE(review): this script relies on the project-local `data.MNIST` and
# `dbn.DBN` APIs; argument meanings that are not evident from this file are
# marked below as assumptions to confirm.

# Number of output labels (MNIST digits 0-9). Used both to build the DBN and
# to convert the mean of the prediction/label product into an accuracy.
NUM_LABELS = 10


def _error_rate(predicted, labels, num_labels=NUM_LABELS):
    """Return 1 - accuracy of `predicted` against `labels`.

    Assumes `predicted` and `labels` are same-shaped NumPy arrays of one-hot
    rows, shape (n_samples, num_labels) - TODO confirm against
    DBN.classify and MNIST.load_labels. Under that assumption, each correctly
    classified row contributes exactly one 1 to the elementwise product, so
    mean() * num_labels is the fraction of correct samples.
    """
    return 1 - (predicted * labels).mean() * num_labels


# Creates a DBN with a 784-500-500-2000 layer stack (784 = 28*28 input pixels)
# and NUM_LABELS labels. NOTE(review): 0.1 is presumably the learning rate -
# confirm in dbn.DBN.
d = DBN([28 * 28, 500, 500, 2000], NUM_LABELS, 0.1)

# Loads the images and labels from the MNIST database.
images = MNIST.load_images('images')
labels = MNIST.load_labels('labels')

# Uses the first 60000 samples as the training set.
img = images[0:60000]
lbl = labels[0:60000]

# Loads the test images and labels from the MNIST database.
tst_img = MNIST.load_images('test_images')
tst_lbl = MNIST.load_labels('test_labels')

# Pre-trains the DBN. NOTE(review): the meaning of 5 and 50 (e.g. epochs and
# batch size) is defined by dbn.DBN.pre_train - confirm there.
d.pre_train(img, 5, 50)

# Runs 100 rounds of label training and prints the test error after each
# round so training progress can be monitored.
for _ in range(100):
    d.train_labels(img, lbl, 50, 50)
    tst_class = d.classify(tst_img, 10)
    # Previously referenced `dbn.number_labels`, but the `dbn` module is not
    # bound here (only `DBN` is imported), which raised NameError.
    print('Error over test data: {0}'.format(
        _error_rate(tst_class, tst_lbl, NUM_LABELS)))

# Final evaluation on the test images. NOTE(review): classify is called with
# 20 here but 10 inside the loop; the argument's meaning is defined by
# dbn.DBN.classify - confirm which value is intended.
tst_class = d.classify(tst_img, 20)

# Calculates the final error rate.
err_test = _error_rate(tst_class, tst_lbl, NUM_LABELS)

# Prints the error rate.
print(err_test)