Exemplo n.º 1
0
def main(rand_nodes=False,
         script_name=__file__.split('/')[-1][0:-3],
         from_disk=1):
    """Optimise 'opt_rate', rerun the 'hist' networks and save the figure.

    rand_nodes  -- forwarded to get_setup() as the 'rand_nodes' keyword.
    script_name -- name handed to get_storages(); defaults to this file's
                   name with the last three characters stripped.
    from_disk   -- value repeated five times and passed to optimize() and
                   run().
    """
    builder_kwargs = get_kwargs_builder()
    dinfo, networks = get_setup(rand_nodes=rand_nodes)
    storages = get_storages(script_name, networks.keys(), dinfo)

    results = {}
    opt_result = optimize('opt_rate', networks, [from_disk] * 5, storages,
                          x0=200.0)
    results.update(opt_result)

    # Apply the value found for each network to that network only.
    for key in sorted(networks['hist'].keys()):
        net = networks['hist'][key]
        optimum = results['opt_rate'][net.get_name()]
        set_optimization_val(optimum, [net])

    hist_result = run('hist', networks, [from_disk] * 5, storages,
                      'mean_rates', t_start=builder_kwargs['start_rec'])
    results.update(hist_result)

    figure_kwargs = {
        'n_rows': 1,
        'n_cols': 1,
        'w': 1000.0,
        'h': 800.0,
        'fontsize': 16,
    }
    fig, axs = pl.get_figure(**figure_kwargs)

    show_opt_hist(results, axs, '$GPe_{+d}^{TA}$')
    storages['fig'].save_fig(fig)

    if DISPLAY:
        pylab.show()
Exemplo n.º 2
0
def main(rand_nodes=False,
         script_name=__file__.split('/')[-1][0:-3],
         from_disk=0):
    """Run the 'IV', 'IF' and 'FF' sweeps, the rate optimisation and the
    'hist' run, then plot and save the result.

    rand_nodes  -- forwarded to get_setup() as the 'rand_nodes' keyword.
    script_name -- name handed to get_storages(); defaults to this file's
                   name with the last three characters stripped.
    from_disk   -- value repeated and passed to run_XX(), optimize() and
                   run() (4, 1 and 2 copies respectively).
    """
    k = get_kwargs_builder()

    dinfo, dn = get_setup(**{'rand_nodes': rand_nodes})

    dn = modify(dn)
    ds = get_storages(script_name, dn.keys(), dinfo)

    # Stimulus values per sweep. Real lists (not map objects) so that every
    # consumer below can iterate them again, on Python 2 and 3 alike.
    dstim = {
        'IV': [float(v) for v in range(-300, 300, 100)],  # curr
        'IF': [float(v) for v in range(0, 500, 100)],     # curr
        'FF': [float(v) for v in range(0, 1500, 100)],    # rate
    }

    d = {}
    for sweep in ('IV', 'IF', 'FF'):
        d.update(run_XX(sweep, dn, [from_disk] * 4, ds, dstim))
    d.update(optimize('opt_rate', dn, [from_disk] * 1, ds, **{'x0': 200.0}))
    set_optimization_val(d['opt_rate']['Net_0'], dn['hist'])
    d.update(run('hist', dn, [from_disk] * 2, ds, 'mean_rates',
                 **{'t_start': k['start_rec']}))

    fig, axs = pl.get_figure(n_rows=2, n_cols=2, w=1000.0, h=800.0,
                             fontsize=16)

    show(dstim, d, axs, NAMES)
    ds['fig'].save_fig(fig)

    # Open the interactive window only when a display is available; the
    # previous test was inverted and tried to show the figure exactly when
    # DISPLAY was unset.
    if os.environ.get('DISPLAY'):
        pylab.show()
Exemplo n.º 3
0
def show_3d(d, attr, **k):
    """Plot, per entry of d, the mean 'set_1' minus 'set_0' difference as a
    3D surface.

    d    -- mapping; every d[key] must provide ['set_0']['SN'][attr] and
            ['set_1']['SN'][attr], objects exposing x_set and y_raw_data.
    attr -- name of the attribute entry to read for each set.
    k    -- 'resolution': side length; x, y and z are each reshaped to
            (resolution, resolution).
            'titles': optional sequence of axis titles, indexed in sorted
            key order. Titles are skipped when it is not given.

    Returns the figure created by ps.get_figure(), one row per key of d.
    """
    models = ['SN']
    res = k.get('resolution')
    titles = k.get('titles')

    fig, axs = ps.get_figure(n_rows=len(d.keys()), n_cols=1, w=500.0,
                             h=800.0, fontsize=12, projection='3d')

    alpha = 0.8
    i = 0
    for model in models:
        for key in sorted(d.keys()):
            obj0 = d[key]['set_0'][model][attr]
            obj1 = d[key]['set_1'][model][attr]

            # Average the raw difference over axis 0 before gridding.
            diff = numpy.mean(obj1.y_raw_data - obj0.y_raw_data, axis=0)
            x, y, z = [numpy.reshape(arr, [res, res])
                       for arr in (obj0.x_set, obj1.x_set, diff)]

            axs[i].plot_surface(x, y, z, cmap='coolwarm', rstride=1,
                                cstride=1,
                                linewidth=0,
                                shade=True,
                                alpha=alpha,
                                antialiased=False,
                                vmin=-40,
                                vmax=40)
            axs[i].set_zlim([-40, 40])

            # 'titles' is optional; without it the axes stay untitled
            # instead of failing on None[i].
            if titles is not None:
                axs[i].set_title(titles[i])
            i += 1

    for ax in axs:
        ax.view_init(elev=15)

    return fig
Exemplo n.º 4
0
def show_hist(name, d, models=['M1', 'M2', 'FS', 'GA', 'GI', 'ST', 'SN'], **k):
    """Draw step histograms of v[model][name] for every entry v of d.

    name   -- key of the object in d[key][model] whose .hist() is called.
    d      -- mapping; entries are visited in sorted key order.
    models -- model names, one subplot row each (the default list is only
              read, never modified).
    k      -- 'labels': optional labels, one per sorted key of d (defaults
              to the sorted keys). All remaining keywords are forwarded to
              .hist() together with the label and line settings.

    When d[key][model] contains 'spike_statistic', its rates ('mean',
    'std', 'CV'), rounded to two decimals, are appended to the label.

    Returns (fig, axs) as created by ps.get_figure().
    """
    fig, axs = ps.get_figure(n_rows=len(models),
                             n_cols=1,
                             w=1000.0,
                             h=800.0,
                             fontsize=10)
    labels = k.pop('labels', sorted(d.keys()))
    colors = misc.make_N_colors('jet', len(labels))

    for j, key in enumerate(sorted(d.keys())):
        v = d[key]

        for i, model in enumerate(models):
            # The key was previously misspelled in this test, so the
            # statistics never made it into the label.
            if 'spike_statistic' in v[model]:
                rates = v[model]['spike_statistic'].rates
                # Round a copy for display; the stored rates stay intact.
                s = str({
                    'mean': round(rates['mean'], 2),
                    'std': round(rates['std'], 2),
                    'CV': round(rates['CV'], 2)
                })
            else:
                s = ''

            k.update({
                'label': (model + ' ' + labels[j] + ' ' + s),
                'histtype': 'step',
                'linestyle': 'solid',
                'color': colors[j],
                'linewidth': 2.0
            })
            v[model][name].hist(ax=axs[i], **k)

            # Pin the lower y limit to zero, keep the upper one.
            ylim = list(axs[i].get_ylim())
            ylim[0] = 0.0
            axs[i].set_ylim(ylim)
            axs[i].legend_box_to_line()

    return fig, axs
Exemplo n.º 5
0
def plot_optimal(size):
    """Draw the grid returned by get_optimal(size) as a pcolor map.

    Returns the figure created by ps.get_figure().
    """
    fig, axs = ps.get_figure(n_rows=1,
                             n_cols=1,
                             w=400.0*0.8,
                             h=250.0*0.8,
                             fontsize=24)
    x, y, z = get_optimal(size)

    ax = axs[0]
    ax.pcolor(x, y, z, cmap='coolwarm', vmin=-40, vmax=40)

    # Shift the axes slightly right and narrow them.
    bbox = ax.get_position()
    new_position = [bbox.x0*1.05, bbox.y0, bbox.width*0.85, bbox.height]
    ax.set_position(new_position)

    ax.set_xlabel('Action 1')
    ax.set_ylabel('Action 2')
    ax.my_set_no_ticks(xticks=3, yticks=3)

    return fig
Exemplo n.º 6
0
def plot_raw(d, d_keys, attr='mean_coherence'):
    """For every key of d, plot 'Net_0' (blue) and 'Net_1' (red) per model.

    d      -- mapping with d[key][net][model][attr] holding plottable data.
    d_keys -- mapping from key to the text used in the subplot titles.
    attr   -- attribute entry to plot.

    One 1x4 figure is created per key; `models` is read from the enclosing
    module scope. Nothing is returned.
    """
    figure_kwargs = dict(n_rows=1,
                         n_cols=4,
                         w=800.0 * 0.65 * 2,
                         h=300.0 * 0.65 * 2,
                         fontsize=8,
                         frame_hight_y=0.5,
                         frame_hight_x=0.9,
                         title_fontsize=8)

    for key in sorted(d.keys()):
        fig, axs = ps.get_figure(**figure_kwargs)

        for i, model in enumerate(models):
            ax = axs[i]
            for net_name, colour in (('Net_0', 'b'), ('Net_1', 'r')):
                ax.plot(d[key][net_name][model][attr], colour)
            ax.set_title(d_keys[key] + ' ' + model)
Exemplo n.º 7
0
def plot_raw(d, d_keys, attr='mean_coherence'):
    """Plot the raw `attr` traces of 'Net_0' and 'Net_1' for each key of d.

    A new figure with one row and four columns is made per key (sorted
    order). In subplot i the trace of models[i] is drawn for 'Net_0' in
    blue and 'Net_1' in red, titled with d_keys[key] and the model name.
    `models` comes from the enclosing module scope. Returns None.
    """
    width = 800.0*0.65*2
    height = 300.0*0.65*2

    for key in sorted(d.keys()):
        fig, axs = ps.get_figure(n_rows=1, n_cols=4, w=width, h=height,
                                 fontsize=8, frame_hight_y=0.5,
                                 frame_hight_x=0.9, title_fontsize=8)
        title_prefix = d_keys

        for i, model in enumerate(models):
            panel = axs[i]
            panel.plot(d[key]['Net_0'][model][attr], 'b')
            panel.plot(d[key]['Net_1'][model][attr], 'r')
            panel.set_title(title_prefix[key]+' '+model)
Exemplo n.º 8
0
def show_neuron_numbers(d, models, **k):
    """Bar-plot, per model, the number of ids summed over the 'set*' entries.

    d      -- mapping; only its first entry (in the dict's own key order)
              is inspected.
    models -- model names; one bar each.
    k      -- 'labels': optional x tick labels (defaults to models).

    For each model the sets (keys starting with 'set', sorted) are walked
    in order and len(<set>[model]['firing_rate'].ids) is accumulated;
    counting stops at the first set that lacks the model.

    Returns the figure created by ps.get_figure().
    """
    attr = 'firing_rate'
    labels = k.pop('labels', models)

    fig, axs = ps.get_figure(n_rows=1, n_cols=1, w=500.0, h=500.0,
                             fontsize=10)
    ax = axs[0]

    l_ids = []
    for model in models:
        # list(...) keeps the first-key lookup working on Python 3, where
        # dict.keys() is a view and cannot be indexed.
        v = d[list(d.keys())[0]]
        set_names = sorted(s for s in v.keys() if s[0:3] == 'set')

        ids = 0
        for set_name in set_names:
            if model not in v[set_name].keys():
                break
            ids += len(v[set_name][model][attr].ids)
        l_ids.append(ids)

    obj = Data_bar(**{'y': numpy.array(l_ids)})
    obj.bar(ax)

    ax.set_xticklabels(labels)
    return fig
Exemplo n.º 9
0
def show_bulk(d, models, attr, **k):
    """Plot d[name][set][model][attr] for every model, name and set.

    d      -- mapping name -> mapping whose keys starting with 'set' hold
              per-model data; names are visited in sorted order.
    models -- model names; models[i] is drawn on row i of a 7-row figure.
    attr   -- attribute entry whose .plot() is called.
    k      -- 'res' (required): grid side length; x ticks are placed at the
              flat indices of the diagonal of a res x res grid.
              'labels': optional legend labels, one per sorted name.

    Each name gets one colour and each set one line style ('-' for the
    first set, '--' for the second; only two styles are defined). Sets are
    sorted so that the style matches the 'Action <i>' legend entries. The
    legend is drawn on the first row only.

    Returns the figure created by ps.get_figure().
    """
    linestyle = ['-', '--']

    labels = k.pop('labels', sorted(d.keys()))
    colors = misc.make_N_colors('jet', max(len(labels), 1))

    fig, axs = ps.get_figure(n_rows=7, n_cols=1, w=1200.0, h=800.0,
                             fontsize=10)
    res = k.pop('res')

    # Flat (row-major) indices of the diagonal elements of a res x res grid.
    xticks = [j*(res+1) for j in range(res)]

    names = sorted(d.keys())
    for row, model in enumerate(models):
        ax = axs[row]
        max_set = 0  # stays 0 when d is empty instead of being undefined
        for j, name in enumerate(names):
            v = d[name]
            # Sorted so that set index i always maps to linestyle[i].
            sets = sorted(s for s in v.keys() if s[0:3] == 'set')
            max_set = len(sets)
            for i, _set in enumerate(sets):

                if model not in v[_set].keys():
                    break

                obj = v[_set][model][attr]
                obj.plot(ax, **{'color': colors[j],
                                'linestyle': linestyle[i],
                                'label': labels[j]
                                })

        ax.set_ylabel(model+' (spikes/s)')
        ax.set_xticks(xticks)

        # Only the first row carries the legend.
        if row != 0:
            continue

        # Get artists and labels for legend and choose which ones to display:
        # the first line of every name (one entry per colour).
        handles, _labels = ax.get_legend_handles_labels()
        if max_set:
            display = set(range(0, len(names)*max_set, max_set))
        else:
            display = set()

        # Create custom artists, one black line per set / line style.
        linetype_labels = []
        linetype_handles = []
        for i in range(max_set):
            linetype_handles.append(pylab.Line2D((0, 1), (0, 0),
                                                 color='k',
                                                 linestyle=linestyle[i]))
            linetype_labels.append('Action '+str(i))

        # Create legend from custom artist/label lists.
        shown_handles = [handle for i, handle in enumerate(handles)
                         if i in display]
        shown_labels = [label for i, label in enumerate(_labels)
                        if i in display]
        ax.legend(shown_handles+linetype_handles,
                  shown_labels+linetype_labels,
                  bbox_to_anchor=(1.22, 1.))

    return fig
Exemplo n.º 10
0
'''
Created on May 8, 2014

@author: lindahlm
'''

import pylab
import numpy
import pprint
import core.plot_settings as pl
pp = pprint.pprint

# Demo: replace the box-shaped legend handles of step histograms with plain
# lines that carry the same colour and line style.
_, axs = pl.get_figure(n_rows=1, n_cols=1, w=1000.0, h=800.0, fontsize=16)

ax = axs[0]  #pylab.subplot(111)
r = numpy.random.random(100)
ax.hist(r, **{'histtype': 'step', 'label': 'test 1'})
r = numpy.random.random(100)
ax.hist(r, **{'histtype': 'step', 'label': 'test 2', 'linestyle': 'dashed'})

h, labels = ax.get_legend_handles_labels()

# Build one Line2D proxy per histogram handle. The public getters are used
# instead of the private _edgecolor/_linestyle attributes.
artist = []
for hh in h:
    color = hh.get_edgecolor()
    linestyle = hh.get_linestyle()
    obj = pylab.Line2D((0, 1), (0, 0), color=color, linestyle=linestyle)
    artist.append(obj)

ax.legend(artist, labels)
Exemplo n.º 11
0
'''
Created on Jul 27, 2014

@author: mikael
'''
import numpy
import core.plot_settings as ps
import pylab
# Rows of y correspond to the entries of x; every column of y is drawn as
# one line, scaled by the largest value in the whole table.
y = numpy.array([[1041, 2422, 2009, 1980],
                 [1061, 2387, 1968, 1856],
                 [1014, 2167, 1960, 1788]])
x = numpy.array([8, 6, 4])

fig, axs = ps.get_figure(n_rows=1, n_cols=1, w=500.0, h=500.0,
                         fontsize=24, linewidth=4)
ax = axs[0]

y_rel = y / float(numpy.max(y))
for i in range(y.shape[1]):
    ax.plot(x, y_rel[:, i])

ax.my_set_no_ticks(xticks=3)
ax.set_ylabel('Rel performance')
ax.set_xlabel('MSN activated (%)')
pylab.show()
Exemplo n.º 12
0
'''
Created on Jul 27, 2014

@author: mikael
'''
import numpy
import core.plot_settings as ps
import pylab
# Table of values: one row per entry of x, one column per plotted line.
y = numpy.array([
    [1041, 2422, 2009, 1980],
    [1061, 2387, 1968, 1856],
    [1014, 2167, 1960, 1788],
])
x = numpy.array([8, 6, 4])

figure_kwargs = {'n_rows': 1, 'n_cols': 1, 'w': 500.0, 'h': 500.0,
                 'fontsize': 24, 'linewidth': 4}
fig, axs = ps.get_figure(**figure_kwargs)
ax = axs[0]

# Normalise every column by the global maximum before plotting it.
scale = float(numpy.max(y))
for i, column in enumerate(y.T):
    ax.plot(x, column/scale)

ax.my_set_no_ticks(xticks=3)
ax.set_ylabel('Rel performance')
ax.set_xlabel('MSN activated (%)')
pylab.show()
Exemplo n.º 13
0
'''
Created on May 8, 2014

@author: lindahlm
'''

import pylab
import numpy
import pprint
import core.plot_settings as pl
pp=pprint.pprint


# Demo: draw two step histograms of random data and build line-shaped
# legend proxies that reuse each histogram's colour and line style.
_, axs=pl.get_figure(n_rows=1, n_cols=1, w=1000.0, h=800.0, fontsize=16) 


ax=axs[0]#pylab.subplot(111)
# Two independent samples of 100 uniform random numbers each.
r=numpy.random.random(100)
ax.hist(r, **{'histtype':'step', 'label':'test 1'})
r=numpy.random.random(100)
ax.hist(r, **{'histtype':'step', 'label':'test 2', 'linestyle':'dashed'})


# Handles and labels of the two histograms drawn above.
h, labels=ax.get_legend_handles_labels()

# One Line2D per handle, copying its edge colour and line style.
# NOTE(review): _edgecolor and _linestyle are private attributes of the
# handle objects - consider the public getters; confirm against the
# matplotlib version in use.
artist=[]
for hh in h:
    color=hh._edgecolor
    linestyle=hh._linestyle
    obj=pylab.Line2D((0,1),(0,0), color=color, linestyle=linestyle)
    artist.append(obj)