-
Notifications
You must be signed in to change notification settings - Fork 0
/
toynn.py
56 lines (39 loc) · 1.35 KB
/
toynn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
Toy fully-connected multi-layered neural networks
"""
import numpy as np
import matplotlib.pyplot as plt
from utils import generate_data
from layers import Layer, Network
from tqdm import trange
def build_fullyconnected(norm=np.inf, nhidden=5, fa1=True, fa2=True):
    """Generate toy data and wrap it in a two-layer fully connected network.

    Parameters
    ----------
    norm : float, optional
        Forwarded as ``norm`` to ``generate_data``. The default ``np.inf``
        corresponds to data "in a box", per the original comment.
    nhidden : int, optional
        Hidden-layer width; used in both layer shape tuples.
    fa1 : bool, optional
        ``feedback_alignment`` flag for the first layer.
    fa2 : bool, optional
        ``feedback_alignment`` flag for the second layer.

    Returns
    -------
    Network
        ``Network(X, y, [first_layer, second_layer])`` built on the freshly
        generated data.
    """
    # Data is generated before the layers are created, matching the
    # original call order (relevant if either step draws random numbers).
    inputs, targets = generate_data(norm=norm)

    # Shapes (nhidden, 2) then (1, nhidden): presumably 2 input features ->
    # nhidden hidden units -> 1 output. TODO confirm against layers.Layer.
    first_layer = Layer((nhidden, 2), feedback_alignment=fa1)
    second_layer = Layer((1, nhidden), feedback_alignment=fa2)

    return Network(inputs, targets, [first_layer, second_layer])
def sim(norm, nh, fa1, fa2, numiter=10000):
    """Build a network and call it ``numiter`` times, recording each return value.

    Parameters
    ----------
    norm : float
        Forwarded to ``build_fullyconnected`` as ``norm``.
    nh : int
        Forwarded to ``build_fullyconnected`` as ``nhidden``.
    fa1, fa2 : bool
        Forwarded to ``build_fullyconnected`` unchanged.
    numiter : int, optional
        Number of ``net()`` calls; progress is shown with ``tqdm.trange``.

    Returns
    -------
    tuple
        ``(np.array(values), net)``. ``values`` are the results of the
        successive ``net()`` calls (the script plots them as the training
        error), and ``net`` is the network after the last call.
    """
    net = build_fullyconnected(norm=norm, nhidden=nh, fa1=fa1, fa2=fa2)
    history = []
    for _ in trange(numiter):
        history.append(net())
    return np.array(history), net
def main(norm=np.inf, nhidden=50, numiter=10000):
    """Train a toy network, then plot its training curve and held-out predictions.

    The defaults reproduce the original script: ``norm=np.inf``, 50 hidden
    units, feedback alignment enabled on both layers, and 10000 iterations.

    Parameters
    ----------
    norm : float, optional
        Norm passed to ``generate_data`` for both training and held-out data.
    nhidden : int, optional
        Hidden-layer width passed to ``sim``.
    numiter : int, optional
        Number of training iterations passed to ``sim``.
    """
    objective, net = sim(norm, nhidden, True, True, numiter=numiter)

    # Fresh held-out samples. Only the inputs are needed, because the plot is
    # colored by the network's predictions rather than the true labels.
    X_holdout, _ = generate_data(norm=norm, nsamples=5000)
    yhat = net.predict(X_holdout)[0]

    # Training curve: one value per sim() iteration.
    plt.figure()
    plt.plot(np.arange(objective.size), objective)
    plt.xlabel('Iteration ($k$)')
    plt.ylabel('Training error ($f(k)$)')

    # Held-out points (not the training data), colored by predicted output.
    # Axis limits assume the generated data lies in [-1, 1]^2; TODO confirm
    # against utils.generate_data.
    plt.figure()
    plt.scatter(X_holdout[0], X_holdout[1], s=50, c=yhat, cmap='seismic')
    plt.gca().set_aspect('equal')
    plt.xlim(-1, 1)
    plt.ylim(-1, 1)

    # show() blocks until the windows are closed, so nothing useful can
    # follow it.
    plt.show()


if __name__ == '__main__':
    main()