# Provenance: forked from rronan/Boltzmann-s-Cuisine — unsupervised.py
# (GitHub page chrome and gutter line numbers from the scrape removed.)
# -*- coding: utf-8 -*-
"""
Created on Sun Jan 3 17:25:34 2016
@author: navrug
"""
import timeit
import numpy as np
import sys
from RBM import RBM
from sklearn.linear_model import LogisticRegression
import random
# Reproducibility: fix NumPy's global RNG state before any sampling.
np.random.seed(seed=0)

# --- Hyperparameters ---------------------------------------------------------
learning_rate = 0.01    # SGD step size for the RBM updates
training_epochs = 50    # full passes over the training set
batch_size = 20         # examples per minibatch
n_chains = 20           # persistent Gibbs chains (sampling section)
output_folder = 'rbm_plots'
n_hidden = 2000         # number of hidden units in the RBM
dropout_rate = 0.5      # fraction of hidden units dropped during training
k = 20                  # Gibbs steps per (P)CD-k update

do_report = True

# Create a report to be saved at the end of execution (when running on the
# remote server).
if do_report:
    report = {"learning_rate": learning_rate,
              "training_epochs": training_epochs,
              "batch_size": batch_size,
              "n_chains": n_chains,
              # FIX: was a hard-coded 'rbm_plots'; reference the variable so
              # the report cannot silently drift from the actual setting.
              "output_folder": output_folder,
              "n_hidden": n_hidden,
              "dropout_rate": dropout_rate,
              "k": k,
              "costs": np.zeros(training_epochs),      # per-epoch mean cost
              "accuracy": np.zeros(training_epochs),   # per-epoch test accuracy
              "pretraining_time": 0}                   # filled in after training
# Load the label-augmented training matrix: the first `n_labels` columns are a
# one-hot label block, the remaining columns are the visible units.
data = np.load('train_data.npy')
n_labels = 20
n_visible = data.shape[1] - n_labels

# Split of train_data for cross-validation: hold out 1/n_fold of the rows.
n_fold = 3
test_n = data.shape[0] // n_fold  # FIX: floor division instead of int(a/b)
# NOTE(review): this seeds the stdlib RNG, but only np.random is used below —
# presumably vestigial; kept so module-level side effects are unchanged.
random.seed(0)
permutation = np.random.permutation(data.shape[0])
test_idx = permutation[:test_n]
test_set = data[test_idx, :]
train_idx = permutation[test_n:]
train_set = data[train_idx, :]
del data  # free the full matrix before materialising the batch list

# Integer class labels recovered from the one-hot block.
test_labels = np.argmax(test_set[:, :n_labels], axis=1)
train_labels = np.argmax(train_set[:, :n_labels], axis=1)

# Minibatches for RBM training; the label columns are stripped because the
# RBM is trained unsupervised on the visible units only.
batches = [train_set[i:i + batch_size, n_labels:]
           for i in range(0, train_set.shape[0], batch_size)]
# Dedicated RNG so the RBM's weight initialisation is reproducible
# independently of the global NumPy seed set earlier in the script.
rng = np.random.RandomState(123)

# Construct the RBM over the visible units (labels excluded).
rbm_config = dict(n_visible=n_visible,
                  n_hidden=n_hidden,
                  dropout_rate=dropout_rate,
                  batch_size=batch_size,
                  np_rng=rng)
rbm = RBM(**rbm_config)
#%%============================================================================
# Training the RBM
#==============================================================================
start_time = timeit.default_timer()

accuracies = []   # test accuracy per epoch
argmax_acc = 0    # best accuracy seen so far (checkpointing criterion)
# FIX: range instead of Python-2-only xrange (identical here on either version).
for epoch in range(training_epochs):
    epoch_time = timeit.default_timer()
    mean_cost = []
    for batch_index, batch in enumerate(batches):
        # One persistent-CD (PCD-k) parameter update on the minibatch.
        rbm.update(batch, persistent=True, k=k)
        sys.stdout.write("\rEpoch advancement: %d%%"
                         % (100 * float(batch_index) / len(batches)))
        sys.stdout.flush()

    # Train a softmax (multinomial logistic regression) on the deterministic
    # hidden representation to monitor feature quality each epoch.
    sys.stdout.write("\rTraining softmax...")
    sm_time = timeit.default_timer()
    # FIX: the 'lbfgs' solver supports only an l2 penalty; the original
    # penalty='l1' makes sklearn raise "Solver lbfgs supports only l2
    # penalties" before any fitting happens.
    softmax_classifier = LogisticRegression(penalty='l2',
                                            C=1.0,
                                            solver='lbfgs',
                                            multi_class='multinomial')
    # An all-ones dropout mask gives the deterministic hidden activations.
    softmax_classifier.fit(
        rbm.propup(train_set[:, n_labels:], np.ones((len(train_set), n_hidden))),
        train_labels)
    sys.stdout.write('\rSoftmax trained in %f minutes.\n'
                     % ((timeit.default_timer() - sm_time) / 60.))

    sys.stdout.write("Evaluating accuracy...")
    cv_time = timeit.default_timer()
    acc = softmax_classifier.score(
        rbm.propup(test_set[:, n_labels:], np.ones((len(test_set), n_hidden))),
        test_labels)
    accuracies.append(acc)
    sys.stdout.write('''\rEpoch %i took %f minutes,
    accuracy (computed in %f minutes) is %f.\n'''
                     % (epoch, ((cv_time - epoch_time) / 60.),
                        ((timeit.default_timer() - cv_time) / 60.), acc))

    if do_report:
        # NOTE(review): rbm.update's cost is never collected here, so
        # mean_cost stays empty; guard against np.mean([]) which yields
        # nan plus a RuntimeWarning.
        report["costs"][epoch] = np.mean(mean_cost) if mean_cost else np.nan
        report["accuracy"][epoch] = acc
        if acc > argmax_acc:
            # FIX: argmax_acc was never updated, so the condition was always
            # true and every epoch clobbered the checkpoint; now only a new
            # best accuracy triggers a save.
            argmax_acc = acc
            report["W"] = rbm.W
            report["hbias"] = rbm.hbias
            report["vbias"] = rbm.vbias
            np.save('report', report)
            sys.stdout.write("Model saved \n")

end_time = timeit.default_timer()
pretraining_time = end_time - start_time
print('Training took %f minutes' % (pretraining_time / 60.))
if do_report:
    # FIX: this assignment sat outside the guard and raised NameError
    # (report undefined) whenever do_report was False.
    report["pretraining_time"] = pretraining_time
    np.save('report', report)
#%%============================================================================
# Classifying with the RBM
#==============================================================================
#%%============================================================================
# Sampling from the RBM
#==============================================================================