コード例 #1
0
def plotHeatmap(dataPath, gaussianDistances, ax1=None):
    """Plot a heatmap of smoothed, per-row-normalised spike counts.

    For every entry ``dist`` of ``gaussianDistances`` the array
    ``<dataPath>/dist<dist>/activity/spikeCountPerExample.npy`` is loaded.
    Its slice ``[25, :, 0]`` is taken (presumably one fixed example -- confirm),
    smoothed with ``movingaverage`` and scaled so the row maximum is 1. The rows
    are stacked in the order of ``gaussianDistances`` and shown with ``imshow``.

    Args:
        dataPath: Root directory containing one ``dist<dist>`` folder per distance.
        gaussianDistances: Sequence of distances; one heatmap row per entry.
        ax1: Optional axes to draw into. When omitted, a new figure is created
            and the plot is saved to ``<dataPath>/multipleAnglesHeatmap.png``.
    """
    b.rcParams['font.size'] = 20
    if ax1 is None:
        b.figure(figsize=(8, 15))
    else:
        b.sca(ax1)

    nE = 1600
    averagingWindowSize = 32

    spikeCount = np.zeros((len(gaussianDistances), nE))
    for i, dist in enumerate(gaussianDistances):
        path = dataPath + '/dist' + str(dist) + '/' + 'activity/'
        spikeCountTemp = np.load(path + 'spikeCountPerExample.npy')
        spikeCount[i, :] = spikeCountTemp[25, :, 0]
        spikeCount[i, :] = movingaverage(spikeCount[i, :], averagingWindowSize)
        rowMax = np.max(spikeCount[i, :])
        # An all-zero row would otherwise turn into NaNs (0 / 0).
        if rowMax > 0:
            spikeCount[i, :] /= rowMax

    b.imshow(spikeCount, aspect='auto', extent=[0, 1, 2, 0])
    b.colorbar()
    b.xlabel('Neuron number (resorted)')
    # The original repeated xlabel here; the y axis indexes the distances.
    b.ylabel('Gaussian distance')

    if ax1 is None:
        b.savefig(dataPath + '/multipleAnglesHeatmap.png', dpi=300, bbox_inches='tight')
コード例 #2
0
def plot_2d_input_weights():
    """Plot the 2-D weight matrix of connection 'XeAe' as a heatmap.

    Draws into figure number ``fig_num`` (module-level) with the colour range
    fixed to ``[0, wmax_ee]``.

    Returns:
        tuple: ``(im2, fig)`` -- the image handle from ``b.imshow`` and the figure.
    """
    name = 'XeAe'
    weights = get_2d_input_weights()
    fig = b.figure(fig_num, figsize=(18, 18))
    im2 = b.imshow(weights, interpolation="nearest", vmin=0, vmax=wmax_ee,
                   cmap=cmap.get_cmap('hot_r'))
    b.colorbar(im2)
    # A space separates the label from the name (previously 'connectionXeAe').
    b.title('weights of connection ' + name)
    fig.canvas.draw()
    return im2, fig
コード例 #3
0
def plot_2d_input_weights():
    """Plot the 2-D weight matrix of connection 'XeAe' as a heatmap.

    Draws into figure number ``fig_num`` (module-level) with the colour range
    fixed to ``[0, wmax_ee]``.

    Returns:
        tuple: ``(im2, fig)`` -- the image handle from ``b.imshow`` and the figure.
    """
    name = 'XeAe'
    weights = get_2d_input_weights()
    fig = b.figure(fig_num, figsize=(18, 18))
    im2 = b.imshow(weights, interpolation="nearest", vmin=0, vmax=wmax_ee,
                   cmap=cmap.get_cmap('hot_r'))
    b.colorbar(im2)
    # A space separates the label from the name (previously 'connectionXeAe').
    b.title('weights of connection ' + name)
    fig.canvas.draw()
    return im2, fig
コード例 #4
0
def plot_2d_excitatory_weights():
	'''
	Draw the 'AeAe' weight matrix (label suffixed with n_e) as a heatmap.

	Uses figure number fig_num and a colour range of [0, wmax_ee].
	Returns the image handle and the figure so the caller can refresh them.
	'''
	conn_label = 'AeAe' + str(n_e)
	weight_matrix = get_2d_excitatory_weights()
	figure = b.figure(fig_num, figsize=(10, 10))
	heat = b.imshow(
		weight_matrix,
		interpolation='nearest',
		vmin=0,
		vmax=wmax_ee,
		cmap=cmap.get_cmap('hot_r'))
	b.colorbar(heat)
	b.title('weights of connection ' + conn_label)
	figure.canvas.draw()
	return heat, figure
コード例 #5
0
def plot_input(rates):
	'''
	Show the current training input example as a grayscale image.

	rates -- input values, reshaped to (28, 28), so it must hold 784 entries.
	The colour scale is fixed to [0, 64]. Returns the image handle and figure.
	'''
	figure = b.figure(fig_num, figsize=(5, 5))
	pixels = rates.reshape((28, 28))
	handle = b.imshow(
		pixels,
		interpolation='nearest',
		vmin=0,
		vmax=64,
		cmap=cmap.get_cmap('gray'))
	b.colorbar(handle)
	b.title('Current input example')
	figure.canvas.draw()
	return handle, figure
コード例 #6
0
def plot_2d_input_weights():
    '''
    Render the input -> excitatory weight matrix as a heatmap for monitoring
    during training, returning the image and figure handles.

    NOTE(review): `name` is not assigned inside this function, so the title
    depends on a module-level `name` existing at call time -- confirm.
    '''
    matrix = get_2d_input_weights()
    figure = b.figure(fig_num, figsize=(18, 18))
    heat = b.imshow(matrix,
                    interpolation="nearest",
                    vmin=0,
                    vmax=wmax_ee,
                    cmap=cmap.get_cmap('hot_r'))
    b.colorbar(heat)
    b.title('weights of connection ' + name)
    figure.canvas.draw()
    return heat, figure
コード例 #7
0
def plot_input():
    '''
    Show the current input example, read from the module-level `rates`,
    as a 28x28 grayscale image with the colour range fixed to [0, 64].

    Returns the image handle and the figure.
    '''
    figure = b.figure(fig_num, figsize=(5, 5))
    pixels = rates.reshape((28, 28))
    handle = b.imshow(pixels, interpolation='nearest', vmin=0, vmax=64,
                      cmap=cmap.get_cmap('gray'))
    b.colorbar(handle)
    b.title('Current input example')
    figure.canvas.draw()
    return handle, figure
コード例 #8
0
ファイル: snn_mnist.py プロジェクト: Hananel-Hazan/stdp-mnist
def plot_2d_input_weights():
    '''
    Display the reshaped input -> excitatory weights as a heatmap so they can
    be watched while training runs.

    Returns the image handle and the figure (figure number fig_num).
    '''
    reshaped = get_2d_input_weights()
    figure = b.figure(fig_num, figsize=(18, 18))
    heat = b.imshow(
        reshaped,
        interpolation='nearest',
        vmin=0,
        vmax=wmax_ee,
        cmap=cmap.get_cmap('hot_r'),
    )
    b.colorbar(heat)
    b.title('Reshaped weights from input -> excitatory layer')
    figure.canvas.draw()
    return heat, figure
コード例 #9
0
def plot_patch_weights():
	'''
	Heatmap of the weights between convolution patches, for inspection
	during training.

	A new figure/axes pair is created via b.subplots on every call; the image
	handle and the figure are returned.
	'''
	patch_weights = get_patch_weights()
	figure, axes = b.subplots(figsize=(8, 8))
	heat = axes.imshow(
		patch_weights,
		interpolation='nearest',
		vmin=0,
		vmax=wmax_ee,
		cmap=cmap.get_cmap('hot_r'))
	b.colorbar(heat)
	b.title('Between-patch connectivity')
	figure.canvas.draw()
	return heat, figure
コード例 #10
0
def plot_patch_weights():
	'''
	Heatmap of the weights between convolution patches, for inspection
	during training.

	Dashed vertical and horizontal lines are drawn at every multiple of n_e
	below n_e * conv_features to mark the block boundaries.
	Returns the image handle and the figure.
	'''
	patch_weights = get_patch_weights()
	figure = b.figure(fig_num, figsize=(8, 8))
	heat = b.imshow(patch_weights, interpolation='nearest', vmin=0, vmax=wmax_ee, cmap=cmap.get_cmap('hot_r'))
	boundaries = xrange(n_e, n_e * conv_features, n_e)
	for pos in boundaries:
		b.axvline(pos, ls='--', lw=1)
		b.axhline(pos, ls='--', lw=1)
	b.colorbar(heat)
	b.title('Between-patch connectivity')
	figure.canvas.draw()
	return heat, figure
コード例 #11
0
def plot_2d_input_weights():
	'''
	Heatmap of the reshaped input -> convolution weights, for use during training.

	Dashed grid lines and ticks are placed every conv_size * n_e_sqrt units,
	up to conv_size * conv_features_sqrt * n_e_sqrt, to mark block boundaries.
	Returns the image handle and the figure.
	'''
	reshaped = get_2d_input_weights()
	figure = b.figure(fig_num, figsize=(18, 18))
	heat = b.imshow(reshaped, interpolation='nearest', vmin=0, vmax=wmax_ee, cmap=cmap.get_cmap('hot_r'))
	block = conv_size * n_e_sqrt
	span = conv_size * conv_features_sqrt * n_e_sqrt
	for edge in xrange(block, span, block):
		b.axvline(edge, ls='--', lw=1)
		b.axhline(edge, ls='--', lw=1)
	b.colorbar(heat)
	b.title('Reshaped input -> convolution weights')
	b.xticks(xrange(0, span, block))
	b.yticks(xrange(0, span, block))
	figure.canvas.draw()
	return heat, figure
コード例 #12
0
def plot_2d_conv_weights():
    '''
    Plot the 2-D weight matrix returned by get_2d_conv_weights as a heatmap,
    for monitoring during training.

    Per the plot title, the matrix covers the input -> convolutional and
    convolutional -> fully-connected weights. The colour range is [0, wmax_ee].

    Returns:
        (im2, fig): the image handle from b.imshow and the figure (fig_num).
    '''
    weights = get_2d_conv_weights()
    fig = b.figure(fig_num, figsize=(18, 18))
    im2 = b.imshow(weights,
                   interpolation='nearest',
                   vmin=0,
                   vmax=wmax_ee,
                   cmap=cmap.get_cmap('hot_r'))
    b.colorbar(im2)
    b.title(
        '2D weights (input -> convolutional, convolutional -> fully-connected)'
    )
    fig.canvas.draw()
    return im2, fig
def plot_2d_input_weights():
    '''
    Plot the reshaped input -> convolutional-layer weights for inspection
    during training.

    Ticks are placed every conv_size units and labelled 1..conv_features on
    the x axis and 1..n_e on the y axis.

    Returns:
        (fig, ax, im, ordering): the figure, its axes, the image handle, and
        the ordering that get_2d_input_weights returns with the weights.
    '''
    matrix, ordering = get_2d_input_weights()
    fig, ax = b.subplots(figsize=(18, 18))
    image = ax.imshow(matrix, interpolation='nearest', vmin=0, vmax=wmax_ee,
                      cmap=cmap.get_cmap('hot_r'))
    b.colorbar(image, fraction=0.016)
    b.title('Reshaped weights from input to convolutional layer', fontsize=18)

    x_positions = xrange(conv_size, conv_size * (conv_features + 1), conv_size)
    x_labels = xrange(1, conv_features + 1)
    b.xticks(x_positions, x_labels)

    y_positions = xrange(conv_size, conv_size * (n_e + 1), conv_size)
    y_labels = xrange(1, n_e + 1)
    b.yticks(y_positions, y_labels)

    b.xlabel('Sorted in order of similarity', fontsize=14)
    b.ylabel('Location in input (from top left to bottom right)', fontsize=14)
    fig.canvas.draw()
    return fig, ax, image, ordering
コード例 #14
0
    # NOTE(review): fragment -- the enclosing `def` line and the code that sets
    # `name`, `connections`, `n_input`, `n_e` and `n_i` are not visible here.
    # If this runs inside a function, `fig_num += 1` needs a `global fig_num`
    # declaration -- TODO confirm.
    fig_num += 1
    # Second character of the connection name picks the source size:
    # 'e' -> n_input, anything else -> n_i.
    if name[1]=='e':
        n_src = n_input
    else:
        n_src = n_i
    # Fourth character picks the target size: 'e' -> n_e, anything else -> n_i.
    if name[3]=='e':
        n_tgt = n_e
    else:
        n_tgt = n_i
        
    # Build a dense (n_src, n_tgt) matrix row by row: rowj[i] is used as the
    # column indices of row i and rowdata[i] as the corresponding values
    # (presumably a sparse connection-matrix row view -- confirm).
    w_post = np.zeros((n_src, n_tgt))
    connMatrix = connections[name][:]
    for i in xrange(n_src):
        w_post[i, connMatrix.rowj[i]] = connMatrix.rowdata[i]
    im2 = b.imshow(w_post, interpolation="nearest", vmin = 0, cmap=cmap.get_cmap('gist_ncar')) #my_cmap
    b.colorbar(im2)
    # NOTE(review): no space between 'connection' and the name in this title.
    b.title('weights of connection' + name)
    
    
    
# Module-level call: executed when this file is run or imported.
plot_2d_input_weights()


# Disabled error-analysis snippet, kept for reference:
# error = np.abs(result_monitor[:,1] - result_monitor[:,0])
# correctionIdxs = np.where(error > 0.5)[0]
# correctedError = [1 - error[i] if (i in correctionIdxs) else error[i] for i in xrange(len(error))]
# correctedErrorSum = np.average(correctedError)
#     
# figure()
# scatter(result_monitor[:,1], result_monitor[:,0], c=range(len(error)), cmap=cm.gray)
# title('Error: ' + str(correctedErrorSum))
コード例 #15
0
    def plotResults(self):
        """Plot the recorded monitors, selected connection weights and errors.

        Opens one figure per enabled group:
          * rate monitors: one subplot per population (rate vs. time in s);
          * spike monitors: one raster subplot per population; for population
            'Ce' the result-monitor columns 0 (population activity, green) and
            1 (desired activity, red), scaled by ``self.nE``, are overlaid;
          * state monitors: 'v' and 'ge' of neuron 0 for every population;
          * the weight matrices of the connections listed in ``plotWeights``;
          * if ``self.plotError``: two desired-vs-population scatter plots,
            coloured by example index and by result-monitor column 2.
        Finishes with ``b.ioff()`` followed by ``b.show()``.
        """
        #------------------------------------------------------------------------------
        # plot results
        #------------------------------------------------------------------------------
        if self.rateMonitors:
            b.figure()
            for i, name in enumerate(self.rateMonitors):
                # Subplot numbers are 1-based; enumerate starts at 0.
                b.subplot(len(self.rateMonitors), 1, i + 1)
                b.plot(self.rateMonitors[name].times/b.second, self.rateMonitors[name].rate, '.')
                b.title('rates of population ' + name)

        if self.spikeMonitors:
            b.figure()
            for i, name in enumerate(self.spikeMonitors):
                b.subplot(len(self.spikeMonitors), 1, i + 1)
                b.raster_plot(self.spikeMonitors[name])
                b.title('spikes of population ' + name)
                if name == 'Ce':
                    # One point per example, evenly spaced (in ms) from half an
                    # (example + resting) period after the start to half a
                    # period before the end of the run.
                    halfPeriodMs = (self.singleExampleTime + self.restingTime) / (2*b.second) * 1000
                    timePoints = np.linspace(halfPeriodMs,
                                             self.runtime/b.second*1000 - halfPeriodMs,
                                             self.numExamples)
                    b.plot(timePoints, self.resultMonitor[:, 0]*self.nE, 'g')
                    b.plot(timePoints, self.resultMonitor[:, 1]*self.nE, 'r')

        if self.stateMonitors:
            b.figure()
            for name in self.stateMonitors:
                b.plot(self.stateMonitors[name].times/b.second, self.stateMonitors[name]['v'][0], label=name + ' v 0')
                b.legend()
                b.title('membrane voltages of population ' + name)

            b.figure()
            for name in self.stateMonitors:
                # This trace is the 'ge' conductance (was mislabelled ' v 0').
                b.plot(self.stateMonitors[name].times/b.second, self.stateMonitors[name]['ge'][0], label=name + ' ge 0')
                b.legend()
                b.title('conductances of population ' + name)

        # Connections whose weight matrices are plotted; uncomment to add more.
        plotWeights = [
        #                 'XeAe',
        #                 'XeAi',
        #                 'AeAe',
        #                 'AeAi',
        #                 'AiAe',
        #                 'AiAi',
        #                'BeBe',
        #                'BeBi',
        #                'BiBe',
        #                'BiBi',
        #                'CeCe',
        #                'CeCi',
                        'CiCe',
        #                'CiCi',
        #                'HeHe',
        #                'HeHi',
        #                'HiHe',
        #                'HiHi',
                        'AeHe',
        #                 'BeHe',
        #                 'CeHe',
                        'HeAe',
        #                 'HeBe',
        #                 'HeCe',
                       ]

        for name in plotWeights:
            b.figure()
            # Characters 1 and 3 of the connection name select the source and
            # target population sizes ('e' -> nE, anything else -> nI).
            if name[1] == 'e':
                nSrc = self.nE
            else:
                nSrc = self.nI
            if name[3] == 'e':
                nTgt = self.nE
            else:
                nTgt = self.nI

            # Densify: rowj[i] gives the column indices of row i, rowdata[i]
            # the matching values.
            w_post = np.zeros((nSrc, nTgt))
            connMatrix = self.connections[name][:]
            for i in xrange(nSrc):
                w_post[i, connMatrix.rowj[i]] = connMatrix.rowdata[i]
            im2 = b.imshow(w_post, interpolation="nearest", vmin=0, cmap=cm.get_cmap('gist_ncar'))
            b.colorbar(im2)
            b.title('weights of connection ' + name)

        if self.plotError:
            desired = self.resultMonitor[:, 1]
            population = self.resultMonitor[:, 0]
            error = np.abs(desired - population)
            # Errors above 0.5 are folded back to 1 - error (presumably a
            # wrap-around distance on [0, 1] -- confirm). Computed once and
            # vectorised instead of an O(n^2) membership test per element.
            correctedErrorSum = np.average(np.where(error > 0.5, 1 - error, error))
            errorTitle = 'Error: ' + str(correctedErrorSum)

            b.figure()
            b.scatter(desired, population, c=range(len(error)), cmap=cm.get_cmap('gray'))
            b.title(errorTitle)
            b.xlabel('Desired activity')
            b.ylabel('Population activity')

            b.figure()
            b.scatter(desired, population, c=self.resultMonitor[:, 2], cmap=cm.get_cmap('gray'))
            b.title(errorTitle)
            b.xlabel('Desired activity')
            b.ylabel('Population activity')

        b.ioff()
        b.show()