Exemplo n.º 1
0
def show_rates(d, **k):
    """Plot mean-rate maps for every network of the 'beta' and 'sw' groups.

    Panels are filled left to right: first every network of d['beta'] in
    sorted key order, then every network of d['sw']. Each d[group][net] entry
    must provide 'mean_rates', 'xlabels' and 'ylabels', which are handed to
    ``plot_rates``; every resulting image gets colour limits [0, 35]. A
    vertical colorbar labelled 'Rate (Hz)' is attached right of panel 1.

    Keyword args:
        scale: multiplier for the figure size and font sizes (default 2).

    Returns:
        The created figure.
    """
    scale = k.get('scale', 2)
    base_width = int(72 / 2.54 * 17.6)  # 17.6 cm converted to points

    fig, axs = ps.get_figure2(n_rows=1,
                              n_cols=5,
                              w=base_width * 2 * scale,
                              h=base_width * 1.5 * scale,
                              fontsize=7 * scale,
                              title_fontsize=7 * scale,
                              frame_hight_y=0.5,
                              frame_hight_x=0.7,
                              gs_builder=gs_builder)

    images = []
    for group in ('beta', 'sw'):
        for net in sorted(d[group].keys()):
            # The next free panel is the number of images drawn so far.
            panel = axs[len(images)]
            entry = d[group][net]
            image = plot_rates(panel, entry['mean_rates'], entry['xlabels'],
                               entry['ylabels'])
            panel.set_title(group + ' ' + net)
            image.set_clim([0, 35])
            images.append(image)

    for idx in (1, 3):
        axs[idx].my_remove_axis(xaxis=False, yaxis=True, keep_ticks=True)

    # Colorbar for the second panel, placed just to its right.
    second_panel = axs[1]
    second_image = images[1]
    box = second_panel.get_position()
    cbar_axes = pylab.axes([box.x0 + box.width * 1.03,
                            box.y0 + box.height * 0.1,
                            0.01,
                            box.height * 0.8])
    cbar = pylab.colorbar(second_image, cax=cbar_axes, orientation="vertical")
    cbar.ax.set_ylabel('Rate (Hz)', rotation=270)

    from matplotlib import ticker
    cbar.locator = ticker.MaxNLocator(nbins=4)
    cbar.update_ticks()
    cbar.ax.tick_params(direction='in', length=1, width=0.5)

    return fig
Exemplo n.º 2
0
def get_fig_axs(scale=4):
    """Create a 2x1 panel figure whose width is 11.6 cm (in points) * scale.

    Args:
        scale: multiplier applied to figure size, font sizes and line width
            (default 4).

    Returns:
        ``(fig, axs)`` from ``core.plot_settings.get_figure2``.
    """
    font_size = 7 * scale
    layout = dict(n_rows=2,
                  n_cols=1,
                  w=72 / 2.54 * 11.6 * scale,
                  h=150 * scale,
                  fontsize=font_size,
                  frame_hight_y=0.5,
                  frame_hight_x=0.7,
                  title_fontsize=font_size,
                  font_size=font_size,
                  text_fontsize=font_size,
                  linewidth=1. * scale,
                  gs_builder=gs_builder)
    from core import plot_settings as ps
    fig, axs = ps.get_figure2(**layout)
    return fig, axs
Exemplo n.º 3
0
def get_fig_axs():
    """Create a 4x2 panel figure whose width is 18 cm (in points) * 4.

    Returns:
        ``(fig, axs)`` from ``core.plot_settings.get_figure2``.
    """
    scale = 4
    font_size = 7 * scale
    layout = dict(n_rows=4,
                  n_cols=2,
                  w=72 / 2.54 * 18 * scale,
                  h=300 * scale,
                  fontsize=font_size,
                  frame_hight_y=0.5,
                  frame_hight_x=0.7,
                  title_fontsize=font_size,
                  font_size=font_size,
                  text_fontsize=font_size,
                  linewidth=1. * scale,
                  gs_builder=gs_builder)
    from core import plot_settings as ps
    fig, axs = ps.get_figure2(**layout)
    return fig, axs
               'fs':256.*4, 
               'noverlap':128*4/2},
        'oi_min':.5,
        'oi_max':1.5,
        'oi_upper':1000.,
        'oi_fs':256*4,
        'keep':['data'],
        'compute_performance_name_and_x':lambda x: [x[0],1]
                   }
    return d

# 4x3 panel figure; relies on `ps`, `scale` and `gs_builder` being defined
# earlier at module level.
fig, axs = ps.get_figure2(
    n_rows=4,
    n_cols=3,
    w=int(72 / 2.54 * 11.6 * (1 + 1. / 2 + 0.2)) * scale,
    h=int(0.85 * 72 / 2.54 * 11.6 * (1 + 1. / 2)) * scale,
    linewidth=1,
    fontsize=7 * scale,
    title_fontsize=7 * scale,
    gs_builder=gs_builder)

models = ['M1', 'M2', 'FS', 'GA', 'GF', 'GI', 'GP', 'ST', 'SN',
          'GI_MS', 'GA_MS', 'GF_MS',
          'GI_FS', 'GA_FS', 'GF_FS']

# models=[m for m in models if not ( m in exclude)]
Exemplo n.º 5
0
def plot_coher(d, labelsy, labelsx=None, title_name='Slow wave'):
    """Plot a coherence / phase-shift heat map with a row-mean bar chart.

    Panel ``axs[0]`` gets a pcolor heat map. Its columns are ``y[0, :]`` of
    every ``d['bar_obj']`` entry (sorted by key), followed by ``y[1, :]`` of
    the same entries. The two halves are labelled "Coherence" and
    "Phase shift". Panel ``axs[1]`` gets a horizontal bar chart of the row
    means of that matrix, with a dashed reference line at x=1. A horizontal
    colorbar titled 'Deviation from base model MSE' is placed above
    ``axs[0]``.

    Args:
        d: dict whose 'bar_obj' entry maps keys to objects exposing a 2-D
            ``y`` array (rows 0 and 1 are used).
        labelsy: y tick labels, drawn in reversed order. Entries found in
            ``nice_labels(version=0)`` are replaced in place, so the caller's
            list is modified.
        labelsx: list that the sorted ``d['bar_obj']`` keys are appended to.
            They are then translated in place via ``nice_labels(version=1)``.
            When omitted, a new empty list is used for this call only.
        title_name: bold label drawn vertically to the right of the plot.

    Returns:
        The created figure.

    NOTE(review): the grid is hard-coded to 14 rows x 8 columns. That
    presumably means 14 y labels and 4 ``bar_obj`` entries -- confirm before
    reusing with other data.
    """
    # A list default is created once and shared by every call, so the keys
    # appended below would pile up across calls; use a fresh list instead.
    if labelsx is None:
        labelsx = []

    fig, axs = ps.get_figure2(
        n_rows=9,
        n_cols=12,
        w=23 / 56. * 17.6 * 72 / 2.54,
        h=23 / 56. * 17.6 * 72 / 2.54,
        fontsize=7,
        frame_hight_y=0.5,
        frame_hight_x=0.7,
        title_fontsize=7,
        gs_builder=gs_builder,
        linewidth=1.)
    for ax in axs:
        ax.tick_params(direction='in', length=2, width=0.5)

    from scripts_inhibition.base_effect_conns import nice_labels

    # Replace raw y labels with their display names (in place).
    y_names = nice_labels(version=0)
    for i, label in enumerate(labelsy):
        if label in y_names:
            labelsy[i] = y_names[label]

    # Row 0 of every bar object feeds the "Coherence" half, row 1 the
    # "Phase shift" half.
    l0 = []
    l2 = []
    for key in sorted(d['bar_obj'].keys()):
        l0.append(d['bar_obj'][key].y[0, :])
        l2.append(d['bar_obj'][key].y[1, :])
        labelsx.append(key)

    z = numpy.transpose(numpy.array(l0 + l2))

    x_names = nice_labels(version=1)
    for i, label in enumerate(labelsx):
        if label in x_names:
            labelsx[i] = x_names[label]

    _vmin = 0
    _vmax = 4
    startx = 0
    starty = 0
    stopy = 14
    stopx = 8
    maxy = 14
    maxx = 8

    # Tick positions at the centre of each heat-map cell.
    posy = numpy.linspace(0.5, maxy - 0.5, maxy)
    posx = numpy.linspace(0.5, maxx - 0.5, maxx)
    # Means are reversed because row 0 of z is drawn at the top (see y edges).
    axs[1].barh(posy,
                numpy.mean(z, axis=1)[::-1],
                align='center',
                color='0.5',
                edgecolor='0.5')
    axs[1].plot([1, 1], [0, stopy], 'k', linestyle='--')
    # y edges run from stopy down to starty, so row 0 of z ends up on top.
    x1, y1 = numpy.meshgrid(numpy.linspace(startx, stopx, maxx + 1),
                            numpy.linspace(stopy, starty, maxy + 1))

    im = axs[0].pcolor(x1, y1, z, cmap='jet', vmin=_vmin, vmax=_vmax)
    axs[0].set_yticks(posy)
    axs[0].set_yticklabels(labelsy[::-1])
    axs[0].set_xticks(posx)
    # The same key labels are used for both halves of the heat map.
    axs[0].set_xticklabels(labelsx * 2, rotation=70, ha='right')
    axs[0].set_ylim([0, maxy])
    axs[0].text(0.25, -0.39, "Coherence",
                ha='center', transform=axs[0].transAxes)
    axs[0].text(0.75, -0.39, "Phase shift",
                ha='center', transform=axs[0].transAxes)
    axs[1].text(0.5, -0.15, "Mean",
                transform=axs[1].transAxes, ha='center', rotation=0)
    axs[1].text(0.5, -0.22, "effect",
                transform=axs[1].transAxes, ha='center', rotation=0)
    font0 = FontProperties()
    font0.set_weight('bold')
    # Positioned with axs[0]'s axes transform (x=1.3 in axs[0] coordinates).
    axs[1].text(1.3, 0.5, title_name,
                va='center',
                transform=axs[0].transAxes,
                rotation=270,
                fontproperties=font0)
    axs[0].text(-0.58, 0.5, "Connection without dop. effect",
                transform=axs[0].transAxes,
                rotation=90,
                ha='center',
                va='center')

    axs[1].my_remove_axis(xaxis=False, yaxis=True)
    axs[1].my_set_no_ticks(xticks=2)
    axs[1].set_xlim([0, 2])

    # Horizontal colorbar just above the heat map.
    box = axs[0].get_position()
    axColor = pylab.axes([box.x0 + 0.0 * box.width,
                          box.y0 + box.height + box.height * 0.15,
                          box.width * 0.8,
                          0.02])
    cbar = pylab.colorbar(im, cax=axColor, orientation="horizontal")
    cbar.ax.set_title('Deviation from base model MSE')
    cbar.ax.tick_params(direction='in', length=1, width=0.5)
    from matplotlib import ticker
    cbar.locator = ticker.MaxNLocator(nbins=4)
    cbar.update_ticks()

    return fig
Exemplo n.º 6
0
# 
# 
# 
# 
# # show_heat_map(dd, 'mean_rate_slices', **k)
# # show_variability(dd, 'mean_rate_slices', **k)
# # show_variability(dd, 'mean_rate_slices', net='Net_01',**k)
# show_variability(dd, 'mean_rate_slices', net='Net_03',**k)
# show_variability(dd, 'mean_rate_slices', net='Net_05',**k)
# show_variability(dd, 'mean_rate_slices', net='Net_06',**k)

scale = 4
# 4x2 panel figure, 8.9 cm wide (in points) times scale.
fig, axs = ps.get_figure2(
    n_rows=4,
    n_cols=2,
    w=int(72 / 2.54 * 8.9) * scale,
    h=450 * scale,
    fontsize=7 * scale,
    title_fontsize=7 * scale,
    gs_builder=gs_builder_new)

print(len(axs))
# pylab.show()
k={'axs':axs,
#    'do_colorbar':False, 
   'fig':fig,
   'model':'SN',
   'print_statistics':False,
    'resolution':10,
Exemplo n.º 7
0
def create_figs(setup, file_name_figs, d, models):
    """Build the firing-rate / rate-difference figure and save it.

    Panels 0-1 are filled by ``show_fr`` and panels 2-3 by ``show_mr_diff``.
    The figure is saved through the storage loaded from ``file_name_figs``,
    as png (dpi 400) and svg.

    Args:
        setup: provides ``plot_fr2``, ``plot_mr_diff`` and ``plot_fig_axs``
            (whose 'fig_and_axes' entry configures ``ps.get_figure2``).
        file_name_figs: path handed to ``Storage_dic.load``.
        d: data forwarded to ``show_fr`` and ``show_mr_diff``.
        models: model names forwarded to ``show_fr`` and ``show_mr_diff``.
    """
    sd_figs = Storage_dic.load(file_name_figs)
    figs = []

    d_plot_fr2 = setup.plot_fr2()
    d_plot_mr_diff = setup.plot_mr_diff()

    fig, axs = ps.get_figure2(**setup.plot_fig_axs().get('fig_and_axes'))
    figs.append(fig)
    for ax in axs:
        ax.tick_params(direction='out', length=2, width=0.5, pad=0.01,
                       top=False, right=False)

    show_fr(d, models, axs[0:2], **d_plot_fr2)
    axs = figs[-1].get_axes()
    for ax in axs[2:4]:
        # Provisional range; the final limits are set again further below.
        ax.set_ylim([0, 70])
        ax.my_set_no_ticks(xticks=4)

    show_mr_diff(d, models, axs[2:4], **d_plot_mr_diff)
    for ax in axs[2:4]:
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, loc='upper left')

    for i, s, c0, c1, rotation in [
            [0, 'Firing rate (Hz)', -0.45, 0., 90],
            [0, r'$MSN_{D1}$', -0.27, 0.5, 90],
            [1, r'$MSN_{D2}$', -0.27, 0.5, 90]]:
        axs[i].text(c0, c1, s,
                    fontsize=7,
                    transform=axs[i].transAxes,
                    verticalalignment='center',
                    horizontalalignment='center',
                    rotation=rotation)

    # Hide the per-panel legends of axes 1-3. Some matplotlib versions return
    # None from a bare ``ax.legend()`` when nothing is labelled, so check the
    # result before using it.
    for ax in axs[1:4]:
        legend = ax.legend()
        if legend is not None:
            legend.set_visible(False)

    axs[0].legend(axs[0].lines[:],
                  d_plot_fr2['labels'][:],
                  bbox_to_anchor=(1.2, 1.8), ncol=1,
                  handletextpad=0.1,
                  frameon=False,
                  columnspacing=0.3,
                  labelspacing=0.2)

    axs[0].my_remove_axis(xaxis=True, yaxis=False, keep_ticks=False)
    axs[2].my_remove_axis(xaxis=True, yaxis=False, keep_ticks=False)
    for i, ax in enumerate(axs):
        ax.set_ylabel('')
        ax.my_set_no_ticks(yticks=3, xticks=3)
        if i == 0:
            ax.set_ylim([0, 60])
            ax.set_xlim([700, 1600])
        if i == 1:
            # Relabel the time axis relative to 700.
            ax.set_xticks([700, 1000, 1300])
            ax.set_xticklabels([0, 300, 600])
            ax.set_ylim([0, 50])
            ax.set_xlim([700, 1600])
        if i == 2:
            ax.set_title('Difference')
            ax.set_ylim([0, 30])
        if i == 3:
            ax.set_ylim([0, 25])

    sd_figs.save_figs(figs, format='png', dpi=400)
    sd_figs.save_figs(figs, format='svg', in_folder='svg')
        print name, net
        if not (net in d[name].keys()):
            i+=1
            continue
        dd['Net_{:0>2}'.format(i)]=d[name][net]
        
        titles.append(name+'_'+net)
        i+=1 
pp(dd)

# Figure width: 17.6 cm converted to points, scaled by (1 - 17/48).
val = int(72 / 2.54 * 17.6 * (1 - 17. / 48))
scale = 1
fig, axs = ps.get_figure2(n_rows=11, n_cols=11,
                          w=val * scale, h=300 * scale,
                          fontsize=7 * scale, title_fontsize=7 * scale,
                          gs_builder=gs_builder)

# Plot options; the keys match those read by show_heat_map (presumably its
# caller below -- confirm).
k = dict(axs=axs,
         do_colorbar=False,
         fig=fig,
         models=['SN'],
         print_statistics=False,
         resolution=10,
         titles=[''] * 25,
         type_of_plot='mean',
         vlim_rate=[-100, 100],
         marker_size=8)
    s = 'Net_{:0>2}'.format(i)
    dd[s] = d[name][net]
    translation[s] = name + '_' + net
    #     titles.append(name+'_'+net)
    i += 1
pp(dd)
pp(translation)
scale = 1
figs = []
fun_call = [show_heat_map, show_variability_several]

for iFig in range(2):
    fig, axs = ps.get_figure2(n_rows=9, n_cols=10,
                              w=72 / 2.54 * 17.6 * scale,
                              h=72 / 2.54 * 17.6 * scale * 1.1,
                              fontsize=7 * scale,
                              title_fontsize=7 * scale,
                              gs_builder=gs_builder)

    # Zero-length ticks, disabled on all four sides of every panel.
    for ax in axs:
        ax.tick_params(direction='in', length=0, width=0.5,
                       top=False, right=False, left=False, bottom=False)
Exemplo n.º 10
0
def plot_coher(d, labelsy, labelsx=None, title_name='Slow wave'):
    """Plot a coherence / phase-shift heat map with a row-mean bar chart.

    Panel ``axs[0]`` gets a pcolor heat map. Its columns are ``y[0, :]`` of
    every ``d['bar_obj']`` entry (sorted by key), followed by ``y[1, :]`` of
    the same entries. The two halves are labelled "Coherence" and
    "Phase shift". Panel ``axs[1]`` gets a horizontal bar chart of the row
    means of that matrix, with a dashed reference line at x=1. A horizontal
    colorbar titled 'Deviation from base model MSE' is placed above
    ``axs[0]``.

    Args:
        d: dict whose 'bar_obj' entry maps keys to objects exposing a 2-D
            ``y`` array (rows 0 and 1 are used).
        labelsy: y tick labels, drawn in reversed order. Entries found in
            ``nice_labels(version=0)`` are replaced in place, so the caller's
            list is modified.
        labelsx: list that the sorted ``d['bar_obj']`` keys are appended to.
            They are then translated in place via ``nice_labels(version=1)``.
            When omitted, a new empty list is used for this call only.
        title_name: bold label drawn vertically to the right of the plot.

    Returns:
        The created figure.

    NOTE(review): the grid is hard-coded to 14 rows x 8 columns. That
    presumably means 14 y labels and 4 ``bar_obj`` entries -- confirm before
    reusing with other data.
    """
    # A list default is created once and shared by every call, so the keys
    # appended below would pile up across calls; use a fresh list instead.
    if labelsx is None:
        labelsx = []

    fig, axs = ps.get_figure2(
        n_rows=9,
        n_cols=12,
        w=23 / 56. * 17.6 * 72 / 2.54,
        h=23 / 56. * 17.6 * 72 / 2.54,
        fontsize=7,
        frame_hight_y=0.5,
        frame_hight_x=0.7,
        title_fontsize=7,
        gs_builder=gs_builder,
        linewidth=1.)
    for ax in axs:
        ax.tick_params(direction='in', length=2, width=0.5)

    from scripts_inhibition.base_effect_conns import nice_labels

    # Replace raw y labels with their display names (in place).
    y_names = nice_labels(version=0)
    for i, label in enumerate(labelsy):
        if label in y_names:
            labelsy[i] = y_names[label]

    # Row 0 of every bar object feeds the "Coherence" half, row 1 the
    # "Phase shift" half.
    l0 = []
    l2 = []
    for key in sorted(d['bar_obj'].keys()):
        l0.append(d['bar_obj'][key].y[0, :])
        l2.append(d['bar_obj'][key].y[1, :])
        labelsx.append(key)

    z = numpy.transpose(numpy.array(l0 + l2))

    x_names = nice_labels(version=1)
    for i, label in enumerate(labelsx):
        if label in x_names:
            labelsx[i] = x_names[label]

    _vmin = 0
    _vmax = 4
    startx = 0
    starty = 0
    stopy = 14
    stopx = 8
    maxy = 14
    maxx = 8

    # Tick positions at the centre of each heat-map cell.
    posy = numpy.linspace(0.5, maxy - 0.5, maxy)
    posx = numpy.linspace(0.5, maxx - 0.5, maxx)
    # Means are reversed because row 0 of z is drawn at the top (see y edges).
    axs[1].barh(posy,
                numpy.mean(z, axis=1)[::-1],
                align='center',
                color='0.5',
                edgecolor='0.5')
    axs[1].plot([1, 1], [0, stopy], 'k', linestyle='--')
    # y edges run from stopy down to starty, so row 0 of z ends up on top.
    x1, y1 = numpy.meshgrid(numpy.linspace(startx, stopx, maxx + 1),
                            numpy.linspace(stopy, starty, maxy + 1))

    im = axs[0].pcolor(x1, y1, z, cmap='jet', vmin=_vmin, vmax=_vmax)
    axs[0].set_yticks(posy)
    axs[0].set_yticklabels(labelsy[::-1])
    axs[0].set_xticks(posx)
    # The same key labels are used for both halves of the heat map.
    axs[0].set_xticklabels(labelsx * 2, rotation=70, ha='right')
    axs[0].set_ylim([0, maxy])
    axs[0].text(0.25, -0.39, "Coherence",
                ha='center', transform=axs[0].transAxes)
    axs[0].text(0.75, -0.39, "Phase shift",
                ha='center', transform=axs[0].transAxes)
    axs[1].text(0.5, -0.15, "Mean",
                transform=axs[1].transAxes, ha='center', rotation=0)
    axs[1].text(0.5, -0.22, "effect",
                transform=axs[1].transAxes, ha='center', rotation=0)
    font0 = FontProperties()
    font0.set_weight('bold')
    # Positioned with axs[0]'s axes transform (x=1.3 in axs[0] coordinates).
    axs[1].text(1.3, 0.5, title_name,
                va='center',
                transform=axs[0].transAxes,
                rotation=270,
                fontproperties=font0)
    axs[0].text(-0.58, 0.5, "Connection without dop. effect",
                transform=axs[0].transAxes,
                rotation=90,
                ha='center',
                va='center')

    axs[1].my_remove_axis(xaxis=False, yaxis=True)
    axs[1].my_set_no_ticks(xticks=2)
    axs[1].set_xlim([0, 2])

    # Horizontal colorbar just above the heat map.
    box = axs[0].get_position()
    axColor = pylab.axes([box.x0 + 0.0 * box.width,
                          box.y0 + box.height + box.height * 0.15,
                          box.width * 0.8,
                          0.02])
    cbar = pylab.colorbar(im, cax=axColor, orientation="horizontal")
    cbar.ax.set_title('Deviation from base model MSE')
    cbar.ax.tick_params(direction='in', length=1, width=0.5)
    from matplotlib import ticker
    cbar.locator = ticker.MaxNLocator(nbins=4)
    cbar.update_ticks()

    return fig
Exemplo n.º 11
0
def create_figs(file_name_figs, from_disks, d, models, setup):
    """Build the MSN D1/D2 mean-rate figure and save it as png and svg.

    Panels 0-1 are filled by ``show_mr`` with ``setup.plot_mr()`` settings.
    Panels 2-3 are filled by ``show_mr`` with ``setup.plot_mr2()`` settings.

    Args:
        file_name_figs: path handed to ``Storage_dic.load``; the loaded
            storage saves the figures.
        from_disks: not used in this function; kept for interface
            compatibility.
        d: data forwarded to ``show_mr``.
        models: not used in this function (``show_mr`` always receives
            ['M1', 'M2']); kept for interface compatibility.
        setup: provides ``plot_mr``, ``plot_mr2`` and ``plot_mr_general``
            (whose 'fig_and_axes' entry configures ``ps.get_figure2``).
    """
    sd_figs = Storage_dic.load(file_name_figs)

    d_plot_mr = setup.plot_mr()
    d_plot_mr2 = setup.plot_mr2()
    figs = []

    fig_and_axes = setup.plot_mr_general().get('fig_and_axes')
    pp(fig_and_axes)
    fig, axs = ps.get_figure2(**fig_and_axes)
    figs.append(fig)
    for ax in axs:
        ax.tick_params(direction='out', length=2, width=0.5, pad=0.01,
                       top=False, right=False)

    show_mr(d, ['M1', 'M2'], axs[2:4], **d_plot_mr2)

    for i, s, c0, c1, rotation in [[2, 'Rel. inh. effect', -0.31, 0., 90],
                                   [0, 'Firing rate (Hz)', -0.45, 0., 90],
                                   [0, r'$MSN_{D1}$', -0.27, 0.5, 90],
                                   [1, r'$MSN_{D2}$', -0.27, 0.5, 90]]:
        axs[i].text(c0, c1, s,
                    fontsize=7,
                    transform=axs[i].transAxes,
                    verticalalignment='center',
                    horizontalalignment='center',
                    rotation=rotation)

    # Hide the legends of panels 2-3. A bare ax.legend() may return None when
    # nothing is labelled, so call it once and check the result.
    for ax in axs[2:4]:
        legend = ax.legend()
        if legend is not None:
            legend.set_visible(False)

    show_mr(d, ['M1', 'M2'], axs[0:2], **d_plot_mr)

    axs[0].legend(axs[0].lines[0:6],
                  d_plot_mr2['labels'][0:6],
                  bbox_to_anchor=(2.4, 2.4), ncol=2,
                  handletextpad=0.1,
                  frameon=False,
                  columnspacing=0.3,
                  labelspacing=0.2)

    legend = axs[1].legend()
    if legend is not None:
        legend.set_visible(False)

    axs[0].my_remove_axis(xaxis=True, yaxis=False, keep_ticks=False)
    axs[2].my_remove_axis(xaxis=True, yaxis=False, keep_ticks=False)
    for i, ax in enumerate(axs):
        ax.set_ylabel('')
        ax.my_set_no_ticks(yticks=3, xticks=4)
        if i == 0:
            ax.set_yticks([0, 10, 20])
        if i == 1:
            ax.set_yticks([0, 7, 14])
            ax.set_ylim([0, 20])

    # Drop the first and the last line of panel 2. Artist.remove() is used
    # because newer matplotlib releases (3.5+) no longer support list-style
    # mutation of Axes.lines.
    if len(axs[2].lines) > 2:
        first_line, last_line = axs[2].lines[0], axs[2].lines[-1]
        first_line.remove()
        last_line.remove()

    sd_figs.save_figs(figs, format='png', dpi=100)
    sd_figs.save_figs(figs, format='svg', in_folder='svg')
Exemplo n.º 12
0
def show_heat_map(d, attr, **k):
    """Draw one set_1 - set_0 contrast heat map per network key of ``d``.

    For each key (sorted), the panel index is the integer after the first
    '_' of the key (e.g. 'Net_3' -> axs[3]). Attribute ``attr`` of model
    'SN' is read from ``d[key]['set_0']`` and ``d[key]['set_1']``. Its
    ``x_set`` arrays and the mean/std of the ``y_raw_data`` difference are
    reshaped to res x res grids and drawn with pcolor.

    Keyword args (read from ``k``; ``k`` is also forwarded to
    ``set_action_selection_marker``):
        do_colorbar (default True), print_statistics (default True),
        resolution, titles, vlim_variance, vlim_CV, vlim_rate, axs, fig,
        type_of_plot ('mean' by default; 'variance' or 'CV' also handled),
        threshold (default 14), pos_ax_titles (default 1.05),
        fontsize_ax_titles (default 7).
        If 'axs' or 'fig' is missing, a new 3x8 figure is created.

    Returns:
        ``(fig, performance)``. ``performance`` maps each key to the sum of
        |z| of the plotted grid, rounded to 2 decimals.
    """
    do_colorbar = k.get('do_colorbar', True)
    # Only 'SN' is plotted. The panel index depends only on the key, so
    # several models would draw over each other.
    models = ['SN']
    print_statistics = k.get('print_statistics', True)
    res = k.get('resolution')
    titles = k.get('titles')
    vlim_variance = k.get('vlim_variance')
    vlim_CV = k.get('vlim_CV')
    vlim_rate = k.get('vlim_rate')

    axs = k.get('axs')
    fig = k.get('fig')
    if not axs or not fig:
        fig, axs = ps.get_figure2(n_rows=3,
                                  n_cols=8,
                                  w=780 / (24. / 7.),
                                  h=910 / (24. / 7.),
                                  fontsize=7,
                                  title_fontsize=7,
                                  gs_builder=gs_builder)
        k['fig'] = fig
        k['axs'] = axs

    type_of_plot = k.get('type_of_plot', 'mean')
    thr = k.get('threshold', 14)
    performance = {}

    for model in models:
        for key in sorted(d.keys()):
            print(key)
            i = int(key.split('_')[1])
            obj0 = d[key]['set_0'][model][attr]
            obj1 = d[key]['set_1'][model][attr]
            diff = obj1.y_raw_data - obj0.y_raw_data
            grids = [obj0.x_set,
                     obj1.x_set,
                     numpy.mean(diff, axis=0),
                     numpy.std(diff, axis=0),
                     numpy.mean(obj0.y_raw_data, axis=0),
                     numpy.mean(obj1.y_raw_data, axis=0)]
            x, y, z, z_std, d0, d1 = [numpy.reshape(g, [res, res])
                                      for g in grids]

            if type_of_plot == 'variance':
                z = z_std
                _vmin, _vmax = vlim_variance
            elif type_of_plot == 'CV':
                z = z_std / numpy.abs(z)
                _vmin, _vmax = vlim_CV
            else:
                _vmin, _vmax = vlim_rate

            # float() keeps the cell size fractional when x_set holds
            # integers; plain '/' floors integers under Python 2 and would
            # misplace the cell centres below.
            stepx = (x[0, -1] - x[0, 0]) / float(res)
            stepy = (y[-1, 0] - y[0, 0]) / float(res)
            # Cell edges for pcolor ...
            x1, y1 = numpy.meshgrid(numpy.linspace(x[0, 0], x[0, -1], res + 1),
                                    numpy.linspace(y[0, 0], y[-1, 0], res + 1))
            # ... and cell centres for the action-selection markers.
            x2, y2 = numpy.meshgrid(numpy.linspace(x[0, 0] + stepx / 2,
                                                   x[0, -1] - stepx / 2, res),
                                    numpy.linspace(y[0, 0] + stepy / 2,
                                                   y[-1, 0] - stepy / 2, res))

            im = axs[i].pcolor(x1, y1, z, cmap='coolwarm',
                               vmin=_vmin, vmax=_vmax)

            for marker in ['d', 's', 'n']:
                set_action_selection_marker(i, d0, d1,
                                            x2, y2, thr,
                                            marker=marker,
                                            **k)

            performance[key] = numpy.round(numpy.sum(numpy.abs(z)), 2)

            axs[i].set_xlabel('Action 1')
            axs[i].set_ylabel('Action 2')
            axs[i].set_xlim([x[0, 0], x[0, -1]])
            axs[i].set_ylim([y[0, 0], y[-1, 0]])

            # Single shared colorbar, drawn once for the 'Net_0' panel.
            if key == 'Net_0' and do_colorbar:
                label = 'Contrast (spike/s)'
                box = axs[0].get_position()
                axColor = pylab.axes([box.x0 + box.width * 2.56,
                                      box.y0 + box.height * 0.1,
                                      0.02,
                                      box.height * 0.8])
                cbar = pylab.colorbar(im, cax=axColor, orientation="vertical")
                cbar.ax.set_ylabel(label, rotation=270)
                cbar.set_ticks([_vmin, _vmax])

            axs[i].text(0.5, k.get('pos_ax_titles', 1.05), titles[i],
                        horizontalalignment='center',
                        transform=axs[i].transAxes,
                        fontsize=k.get('fontsize_ax_titles', 7))

    if print_statistics:
        import pprint
        ax = axs[-1]
        ax.text(0.1, 0.1, pprint.pformat(performance),
                transform=ax.transAxes,
                fontsize=7)

    return fig, performance
Exemplo n.º 13
0
def show_rate_D1_D2_SNR(d, d_plot_fr2):
    """Plot D1, D2 and SNr firing rates on a three-panel figure.

    Args:
        d: data handed to ``show_fr`` (also pretty-printed with ``pp``).
        d_plot_fr2: plot settings. Its 'fig_and_axes' entry configures
            ``ps.get_figure2``, and the whole dict is forwarded to
            ``show_fr``.

    Returns:
        The created figure.
    """
    fig, axs = ps.get_figure2(**d_plot_fr2.get('fig_and_axes'))

    pp(d)
    show_fr(d, ['M1', 'M2', 'SN'], axs, **d_plot_fr2)

    for i, s, c0, c1, rotation in [[1, 'Firing rate (Hz)', -0.45, 0.5, 90],
                                   [0, r'$D1$', -0.25, 0.5, 90],
                                   [1, r'$D2$', -0.25, 0.5, 90],
                                   [2, r'SNr', -0.25, 0.5, 90]]:
        axs[i].text(c0, c1, s,
                    fontsize=7,
                    transform=axs[i].transAxes,
                    verticalalignment='center',
                    horizontalalignment='center',
                    rotation=rotation)

    axs[0].legend(axs[0].lines[0:6],
                  ['Action 1', 'Action 2'],
                  bbox_to_anchor=(1.15, 1.85),
                  ncol=2,
                  handletextpad=0.1,
                  frameon=False,
                  columnspacing=0.3,
                  labelspacing=0.2)

    for i, ax in enumerate(axs):
        ax.my_set_no_ticks(xticks=3, yticks=2)
        if i == 2:
            # Bottom panel: replace the x label with a custom 'Time (ms)' text.
            ax.set_yticks([0, 60])
            ax.set_xlabel('')
            ax.text(0.5, -0.65, 'Time (ms)',
                    fontsize=7,
                    transform=ax.transAxes,
                    ha='center', va='center')
        if i == 1:
            ax.set_yticks([0, 20])
        if i == 0:
            ax.set_yticks([0, 30])
        ax.set_ylabel('')

    # Previously these two calls were re-run inside every iteration of the
    # loop above; they only need to run once, after it.
    # NOTE(review): assumes my_remove_axis is idempotent -- confirm.
    axs[0].my_remove_axis(xaxis=True, yaxis=False, keep_ticks=True)
    axs[1].my_remove_axis(xaxis=True, yaxis=False, keep_ticks=True)

    for ax in axs:
        ax.tick_params(direction='out',
                       length=1,
                       width=0.5,
                       top=False, right=False)
    # The original author noted that pad only took effect when set in a
    # separate pass after the call above, so it is kept separate.
    for ax in axs:
        ax.tick_params(pad=1)

    # Hide the legends of panels 1-2. A bare ax.legend() may return None when
    # nothing is labelled, so call it once and check.
    for ax in axs[1:3]:
        legend = ax.legend()
        if legend is not None:
            legend.set_visible(False)

    return fig
Exemplo n.º 14
0
def create_figs(setup, file_name_figs, d, models):
    """Build all summary figures for a result set and save them to disk.

    Figures are produced in this order:

    1. ``plot_optimal(setup.res)``;
    2. heat maps of type 'mean', then 'variance', then 'CV', one per name in
       ``setup.get_l_mean_rate_slices()`` for each type;
    3. one ``show_bulk`` figure per slice name;
    4. ``show_rate_D1_D2_SNR`` and ``show_rate_all`` for a single network:
       ``'Net_1'`` when ``d`` has more than one key, ``'Net_0'`` when it has
       exactly one, nothing when it is empty.

    A failing rate plot emits a warning (including the cause) instead of
    aborting. All figures are then saved through
    ``Storage_dic.load(file_name_figs)`` as 400-dpi PNG and as SVG
    (``in_folder='svg'``).

    Args:
        setup: Provides the plot configurations (``plot_3d``, ``plot_fr``,
            ``plot_fr2``, ``plot_bulkt``), the slice names
            (``get_l_mean_rate_slices``) and ``res`` for ``plot_optimal``.
        file_name_figs: Passed to ``Storage_dic.load``; the figures are
            saved via the returned object.
        d: Results mapping; presumably keyed by ``'Net_<i>'`` — TODO confirm
            against the caller.
        models: Passed through unchanged to ``show_bulk``.
    """
    sd_figs = Storage_dic.load(file_name_figs)
    figs = []

    d_plot_3d = setup.plot_3d()
    d_plot_fr = setup.plot_fr()
    d_plot_fr2 = setup.plot_fr2()
    d_plot_bulk = setup.plot_bulkt()
    l_mean_rate_slices = setup.get_l_mean_rate_slices()

    figs.append(plot_optimal(setup.res))

    # NOTE(review): the figure created here is never used or saved. The call
    # is kept only in case ps.get_figure2 sets global plot state that later
    # plots rely on — confirm and remove if it does not.
    ps.get_figure2(n_rows=3,
                   n_cols=8,
                   w=780/(24./7.),
                   h=910/(24./7.),
                   fontsize=7,
                   title_fontsize=7,
                   gs_builder=gs_builder)

    pp(d)

    # All 'mean' maps first, then 'variance', then 'CV' (same order as
    # before). Pass a copy so the configuration dict from setup is not
    # mutated.
    for type_of_plot in ('mean', 'variance', 'CV'):
        kwargs_3d = dict(d_plot_3d, type_of_plot=type_of_plot)
        for name in l_mean_rate_slices:
            fig, _perf = show_heat_map(d, name, **kwargs_3d)
            figs.append(fig)

    for name in l_mean_rate_slices:
        figs.append(show_bulk(d, models, name, **d_plot_bulk))

    # Only one network gets rate plots: Net_1 when several exist, otherwise
    # Net_0 (if any).
    n_nets = len(d.keys())
    if n_nets > 1:
        selected_nets = ['Net_1']
    elif n_nets == 1:
        selected_nets = ['Net_0']
    else:
        selected_nets = []

    rate_plots = (('show_rate_D1_D2_SNR', show_rate_D1_D2_SNR, d_plot_fr2),
                  ('show_rate_all', show_rate_all, d_plot_fr))
    for net in selected_nets:
        for label, plot_func, d_plot in rate_plots:
            # Best-effort: a broken rate plot must not prevent saving the
            # other figures, but the cause is reported rather than hidden.
            # Catching Exception (not a bare except) lets Ctrl-C and
            # SystemExit propagate.
            try:
                figs.append(plot_func(d[net], d_plot))
            except Exception as e:
                warnings.warn('failed rate plot (%s, %s): %r'
                              % (net, label, e))

    sd_figs.save_figs(figs, format='png', dpi=400)
    sd_figs.save_figs(figs, format='svg', in_folder='svg')
Exemplo n.º 15
0
        'add_midpoint':False,
        'psd':{'NFFT':128*4, 
               'fs':256.*4, 
               'noverlap':128*4/2},
        'oi_min':.5,
        'oi_max':1.5,
        'oi_upper':1000.,
        'oi_fs':256*4,
                   'keep':['data']}
    return d

# Create a figure with a 4 x 3 grid of panels. 72/2.54 presumably converts
# cm to points (72 pt per inch / 2.54 cm per inch), making the base width
# 11.6 cm — TODO confirm the units ps.get_figure2 expects. Relies on `scale`
# being defined earlier in the script (not visible in this fragment).
fig, axs=ps.get_figure2(n_rows=4,
                         n_cols=3,  
                         w=int(72/2.54*11.6*(1+1./2+0.2))*scale,
                         h=int(0.85*72/2.54*11.6*(1+1./2))*scale,
#                             w=k.get('w',500), 
#                             h=k.get('h',900), 
                        linewidth=1,
                        fontsize=7*scale,
                        title_fontsize=7*scale,
                        gs_builder=gs_builder) 

# Model identifiers to analyse: single names (e.g. 'SN') and underscore-joined
# pairs (e.g. 'GI_ST'). Their meaning is not visible here; presumably they are
# keys into the loaded results — verify against the data.
models=['M1', 'M2', 'FS', 'GA', 'GI', 'GP', 'ST','SN',
                   'GI_ST', 'GP_GP', 'GA_GA', 'GI_GA', 'GI_GI']

# models=[m for m in models if not ( m in exclude)]

# Network keys to compare.
nets=['Net_0', 'Net_1']
attrs=[
       'firing_rate', 
       'mean_coherence', 
       'phases_diff_with_cohere',
        dd['Net_{:0>2}'.format(i)] = d[name][net]

        #         titles.append(name+'_'+net)
        i += 1
# for name, nets in builder:
#     for net in nets:
#         dd['Net_{}'.format(i)]=d[name][net]
#         i+=1
# Dump the assembled results (pp presumably pretty-prints), then print the
# length of the `.y` attribute of Net_00 / set_1 / SN / mean_rate_slices as
# a quick sanity check.
pp(dd)
print len(dd['Net_00']['set_1']['SN']['mean_rate_slices'].y)

figs = []
# One figure with a 19 x 16 grid of panels and a fixed height of 300. The
# width uses 72/2.54 (presumably cm -> points) times 17.6 * 17/48 — TODO
# confirm the units ps.get_figure2 expects.
fig, axs = ps.get_figure2(n_rows=19,
                          n_cols=16,
                          w=int(72 / 2.54 * 17.6 * (17. / 48)),
                          h=300,
                          fontsize=7,
                          title_fontsize=7,
                          gs_builder=gs_builder)
figs.append(fig)
k = {
    'axs':
    axs,
    'do_colorbar':
    False,
    'fig':
    fig,
    'models': ['SN'],
    'print_statistics':
    False,
    'resolution':
Exemplo n.º 17
0
        dd['Net_{:0>2}'.format(i)]=d[name][net]
        
#         titles.append(name+'_'+net)
        i+=1 
# for name, nets in builder:
#     for net in nets:
#         dd['Net_{}'.format(i)]=d[name][net]
#         i+=1 
# Dump the assembled results (pp presumably pretty-prints), then print the
# length of the `.y` attribute of Net_00 / set_1 / SN / mean_rate_slices as
# a quick sanity check.
pp(dd)
print len(dd['Net_00']['set_1']['SN']['mean_rate_slices'].y)

figs=[]
# One figure with a 19 x 16 grid of panels and a fixed height of 300. The
# width uses 72/2.54 (presumably cm -> points) times 17.6 * 17/48 — TODO
# confirm the units ps.get_figure2 expects.
fig, axs=ps.get_figure2(n_rows=19, 
                        n_cols=16,
                        w=int(72/2.54*17.6*(17./48)),
                        h=300,  
                        fontsize=7,
                        title_fontsize=7,
                        gs_builder=gs_builder) 
figs.append(fig)
k={'axs':axs,
   'do_colorbar':False, 
   'fig':fig,
   'models':['SN'],
   'print_statistics':False,
   'resolution':7,
   'threshold':thr,
   'titles':['Only D1',
             'D1 & D2',
             r'No MSN$\to$MSN',
             r'No FSN$\to$MSN',
Exemplo n.º 18
0
# Toy data: an n x n matrix of uniform [0, 1) values. Rows 1 and 8 are
# raised by +1 and +2 so that they stand out in the plots.
n=10
z=numpy.random.random((n,n))
z[1,:]+=1
z[8,:]+=2
# Plot limits and grid parameters. _vmin/_vmax are presumably colour limits;
# the code that consumes these values lies outside this fragment.
_vmin=0
_vmax=4
stepx=1
stepy=1
startx=0
starty=0
stopy=10
stopx=10
res=10

fig, axs=ps.get_figure2(n_rows=5, n_cols=6, w=700, h=500, fontsize=24,
                        frame_hight_y=0.5, frame_hight_x=0.7, title_fontsize=20,
                        gs_builder=gs_builder)        

# Horizontal bars of the row means, centred at 0.5 .. 9.5 (hard-coded for 10
# rows, matching n=10). The means are reversed, so row 0 is drawn at the top
# (y=9.5) and the last row at the bottom.
pos=numpy.linspace(0.5,9.5,10)
axs[1].barh(pos,numpy.mean(z,axis=1)[::-1], align='center')


# ax=pylab.subplot(111)
nets=['Net_'+str(i) for i in range(10)]
# Grid of cell edges: res+1 points per axis. y runs from stopy down to
# starty, so the first mesh row has the largest y.
x1,y1=numpy.meshgrid(numpy.linspace(startx, stopx, res+1),
                   numpy.linspace(stopy, starty, res+1))
# x2,y2=numpy.meshgrid(numpy.linspace(startx+stepx/2, 
#                                     stopx-stepx/2, res),
#                      numpy.linspace(stopy+stepy/2, 
#                                     starty-stepy/2, res))