/
neuronaut_plot.py
125 lines (98 loc) · 4.05 KB
/
neuronaut_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import numpy
from prettyplotlib import brewer2mpl
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
import logging
from PIL import Image
class WeightsPlotter(object):
    """Lay out a network's per-layer weights side by side on one 2-D canvas.

    Each layer is drawn as its weight matrix ``w`` with its bias ``b`` appended
    as an extra bottom row. Layers are separated by ``_GAP`` empty columns.
    Every cell not covered by a layer is NaN, so a colormap's "bad" colour can
    mark gaps and padding.
    """

    # Number of NaN columns left between neighbouring layers.
    _GAP = 2

    def __init__(self, scale=10):
        # NOTE(review): ``scale`` is stored but not used anywhere in this file.
        self.scale = scale
        # Reused output buffer; (re)allocated whenever the required shape changes.
        self.canvas = None

    def plot_weights(self, weights):
        """Tile every ``(w, b)`` layer of *weights* onto a float32 canvas.

        :param weights: iterable of ``(w, b)`` pairs. ``w`` is 2-D, and ``b``
            is reshaped to a single row, so it must hold ``w.shape[1]`` values.
        :returns: the canvas array, of shape (max layer height, total layer
            width + gaps). The same buffer is reused and overwritten by later
            calls while the layout shape stays the same.
        :raises ValueError: if *weights* is empty.
        """
        layers = [numpy.concatenate((w, b.reshape(1, -1)), axis=0)
                  for (w, b) in weights]
        if not layers:
            raise ValueError("plot_weights() needs at least one (w, b) pair")

        height = max(layer.shape[0] for layer in layers)
        width = (sum(layer.shape[1] for layer in layers)
                 + self._GAP * (len(layers) - 1))
        # Reallocate if the layout changed; writing a different layout into
        # the old buffer would fail to broadcast or clip layers.
        if self.canvas is None or self.canvas.shape != (height, width):
            self.canvas = numpy.empty((height, width), dtype=numpy.float32)
        # Reset on every call so cells that no current layer covers (gaps and
        # padding under shorter layers) are NaN, not values from an earlier frame.
        self.canvas.fill(numpy.nan)

        off_x = 0
        for layer in layers:
            rows, cols = layer.shape
            self.canvas[:rows, off_x:off_x + cols] = layer
            off_x += cols + self._GAP
        return self.canvas
###
from multiprocessing import Process, Queue
class IterableQueue(object):
    """Expose a blocking queue as an iterable that ends at a ``None`` sentinel.

    Each step of iteration blocks on ``queue.get()``. Iteration stops when an
    item that *is* ``None`` is received; other falsy items (0, '', []) are
    still yielded.
    """

    def __init__(self, queue):
        self.queue = queue

    def __iter__(self):
        item = self.queue.get()
        while item is not None:
            yield item
            item = self.queue.get()

    def __len__(self):
        # Workaround: the real length is unknown until the sentinel arrives, so
        # report an effectively unbounded one for consumers that size
        # themselves from len(). Note that list(...) uses len() as a size
        # hint, so iterate directly instead of materialising with list().
        return 999999999999999
def render_confusion(file_name, queue, vmin, vmax, divergent, array_shape):
    """Render expected-vs-network-output scatter plots into a video file.

    Frames are read from *queue* (through IterableQueue) until a ``None``
    sentinel arrives; the animation is then encoded to *file_name*.

    :param file_name: output path handed to ``Animation.save``.
    :param queue: source of ``(expected, output)`` frames. Both values are
        mapped with ``x * vmax + vmin`` before plotting, so they are assumed
        to be numpy arrays.
    :param vmin: offset added after scaling; also the lower axis limit.
    :param vmax: scale factor; both axes span ``[vmin, vmin + vmax]``.
    :param divergent: unused here; kept so all render functions share one signature.
    :param array_shape: unused here; kept for the same reason.
    """
    from pylab import plt
    import matplotlib.animation as animation

    plt.close()
    fig = plt.figure()
    # Create the axes once and reuse it. Calling ``fig.add_subplot(111)`` on
    # every frame creates a new Axes each time in matplotlib >= 3.6, which
    # stacks axes and leaves the axis limits on the previous one.
    ax = fig.add_subplot(111)
    lo, hi = vmin, vmin + vmax

    def update_img(frame):
        # Unpack explicitly: tuple parameters (``def f((a, b))``) are Python 2 only.
        expected, output = frame
        ax.cla()
        ax.set_ylim((lo, hi))
        ax.set_xlim((lo, hi))
        # Diagonal reference line: perfect predictions lie on it.
        ax.plot([lo, hi], [lo, hi])
        ax.grid(True)
        ax.set_xlabel("expected output")
        ax.set_ylabel("network output")
        return ax.scatter(expected * vmax + vmin, output * vmax + vmin)

    ani = animation.FuncAnimation(fig, update_img, frames=IterableQueue(queue))
    ani.save(file_name, fps=30,
             extra_args=['-vcodec', 'libvpx', '-threads', '4', '-b:v', '1M'])
def render_weights(file_name, queue, vmin, vmax, divergent, array_shape):
    """Encode a stream of weight snapshots into a colour-mapped video.

    Every frame taken from *queue* is a list of ``(w, b)`` pairs, which
    WeightsPlotter tiles onto one canvas. Cells holding NaN are painted in the
    colormap's "bad" colour (#6ECFF6). Reading stops at the ``None`` sentinel,
    after which the clip is written to *file_name*.

    :param file_name: output path handed to ``Animation.save``.
    :param queue: source of frames, terminated by ``None``.
    :param vmin: lower bound of the colour scale (``imshow`` *vmin*).
    :param vmax: upper bound of the colour scale (``imshow`` *vmax*).
    :param divergent: True selects the 11-class diverging PiYG map, False the
        9-class sequential OrRd map.
    :param array_shape: not used by this renderer.
    """
    from pylab import plt
    import matplotlib.animation as animation

    layout = WeightsPlotter()
    figure = plt.figure(facecolor='gray', frameon=False)
    axes = figure.add_subplot(111)
    axes.set_aspect('equal')
    for axis in (axes.get_xaxis(), axes.get_yaxis()):
        axis.set_visible(False)
    axes.axis('off')

    if divergent:
        brewer_args = ('PiYG', 'diverging', 11)
    else:
        brewer_args = ('OrRd', 'sequential', 9)
    colormap = brewer2mpl.get_map(*brewer_args).mpl_colormap
    colormap.set_bad('#6ECFF6', 1.0)

    # Dict used as a writable cell: the nested function creates the image
    # artist on the first frame and reuses it afterwards (Python 2 has no
    # ``nonlocal``).
    artists = {}

    def update_img(frame):
        canvas = layout.plot_weights(frame)
        image = artists.get('im', None)
        if image is None:
            image = axes.imshow(canvas, cmap=colormap, interpolation='nearest',
                                vmin=vmin, vmax=vmax)
            artists['im'] = image
            rows, cols = canvas.shape[0], canvas.shape[1]
            # 7.2 in wide (intended as 720 px, i.e. assumes 100 dpi); the
            # height follows the canvas aspect ratio.
            figure.set_size_inches([7.2, 7.2 * (rows / float(cols))])
            figure.subplots_adjust(left=0, bottom=0, right=1, top=1,
                                   wspace=None, hspace=None)
        image.set_data(canvas)
        return image

    animation_obj = animation.FuncAnimation(figure, update_img,
                                            frames=IterableQueue(queue))
    animation_obj.save(file_name, fps=30,
                       extra_args=['-vcodec', 'libvpx', '-threads', '4', '-b:v', '1M'])
class AnimationRender(Process):
    """Child process that renders frames pushed from the parent into a video.

    Frames passed to ``add_frame`` travel through a bounded queue to
    *render_function*, which runs as this process's target with the arguments
    ``(file_name, queue, vmin, vmax, divergent, array_shape)``. ``join`` sends
    the ``None`` end-of-stream sentinel and waits for the child to exit.
    """

    def __init__(self, file_name, vmin=-6.0, vmax=6.0, divergent=True,
                 render_function=render_weights, array_shape=()):
        # Bounded so a slow renderer applies back-pressure: add_frame() blocks
        # once 100 frames are pending.
        self.queue = Queue(maxsize=100)
        super(AnimationRender, self).__init__(
            target=render_function,
            args=(file_name, self.queue, vmin, vmax, divergent, array_shape))
        # Ensures the end-of-stream sentinel is queued at most once, even if
        # join() is called repeatedly (e.g. after a timeout).
        self._end_sent = False

    def add_frame(self, frame):
        """Queue *frame* for rendering; blocks while the queue is full."""
        self.queue.put(frame)

    def join(self, timeout=None):
        """Signal end of stream (once) and wait for the render process.

        Accepts the same optional *timeout* as ``Process.join``, so this
        override stays call-compatible with the base class.
        """
        if not self._end_sent:
            self.queue.put(None)  # end-of-stream sentinel read by IterableQueue
            self._end_sent = True
        return super(AnimationRender, self).join(timeout)