def show_rates(d, **k):
    """Draw one mean-rate panel per network and attach a shared colour bar.

    Panels are filled left to right, first with the networks of
    ``d['beta']`` and then with those of ``d['sw']`` (each in sorted key
    order). Every panel is titled ``'<type> <net>'`` and its image colour
    limits are fixed to [0, 35], so the single colour bar drawn to the right
    of ``axs[1]`` applies to all panels.

    Args:
        d: mapping with keys ``'beta'`` and ``'sw'``; each maps a network
            name to a dict holding ``'mean_rates'``, ``'xlabels'`` and
            ``'ylabels'`` for ``plot_rates``.
        **k: ``scale`` (default 2) multiplies the figure size and fonts.

    Returns:
        The figure created by ``ps.get_figure2``.
    """
    scale = k.get('scale', 2)
    base = int(72 / 2.54 * 17.6)
    font = 7 * scale
    fig, axs = ps.get_figure2(
        n_rows=1,
        n_cols=5,
        w=base * 2 * scale,
        h=base * 1.5 * scale,
        fontsize=font,
        title_fontsize=font,
        frame_hight_y=0.5,
        frame_hight_x=0.7,
        gs_builder=gs_builder,
    )

    # Lazily walk (type, net) pairs in the same order as two nested loops.
    panels = ((kind, net) for kind in ('beta', 'sw')
              for net in sorted(d[kind].keys()))
    images = []
    for slot, (kind, net) in enumerate(panels):
        ax = axs[slot]
        entry = d[kind][net]
        image = plot_rates(ax, entry['mean_rates'], entry['xlabels'],
                           entry['ylabels'])
        ax.set_title(kind + ' ' + net)
        image.set_clim([0, 35])
        images.append(image)

    # Inner panels share the y axis with their left neighbour.
    for slot in (1, 3):
        axs[slot].my_remove_axis(xaxis=False, yaxis=True, keep_ticks=True)

    # Colour bar sits just right of the second panel.
    box = axs[1].get_position()
    cax = pylab.axes([box.x0 + box.width * 1.03,
                      box.y0 + box.height * 0.1,
                      0.01,
                      box.height * 0.8])
    cbar = pylab.colorbar(images[1], cax=cax, orientation="vertical")
    cbar.ax.set_ylabel('Rate (Hz)', rotation=270)

    from matplotlib import ticker
    cbar.locator = ticker.MaxNLocator(nbins=4)
    cbar.update_ticks()
    cbar.ax.tick_params(direction='in', length=1, width=0.5)
    return fig
def get_fig_axs(scale=4):
    """Return ``(fig, axs)`` for a two-row, one-column layout.

    Every size handed to ``ps.get_figure2`` is multiplied by *scale*: the
    width is ``72 / 2.54 * 11.6`` (presumably 11.6 cm in points), the height
    is 150, all font sizes are 7 and the line width is 1.
    """
    font_pt = 7 * scale
    layout = dict(
        n_rows=2,
        n_cols=1,
        w=72 / 2.54 * 11.6 * scale,
        h=150 * scale,
        fontsize=font_pt,
        frame_hight_y=0.5,
        frame_hight_x=0.7,
        title_fontsize=font_pt,
        font_size=font_pt,
        text_fontsize=font_pt,
        linewidth=1. * scale,
        gs_builder=gs_builder,
    )
    from core import plot_settings as ps
    fig, axs = ps.get_figure2(**layout)
    return fig, axs
def get_fig_axs(scale=4):
    """Return ``(fig, axs)`` for a four-row, two-column layout.

    Args:
        scale: multiplier applied to every size passed to
            ``ps.get_figure2`` (width ``72 / 2.54 * 18``, height 300, font
            sizes 7, line width 1). Defaults to 4, the value previously
            hard-coded, so ``get_fig_axs()`` behaves exactly as before.

    Returns:
        The ``(fig, axs)`` pair produced by ``ps.get_figure2``.
    """
    kw = {
        'n_rows': 4,
        'n_cols': 2,
        'w': 72 / 2.54 * 18 * scale,
        'h': 300 * scale,
        'fontsize': 7 * scale,
        'frame_hight_y': 0.5,
        'frame_hight_x': 0.7,
        'title_fontsize': 7 * scale,
        'font_size': 7 * scale,
        'text_fontsize': 7 * scale,
        'linewidth': 1. * scale,
        'gs_builder': gs_builder,
    }
    from core import plot_settings as ps
    fig, axs = ps.get_figure2(**kw)
    return fig, axs
'fs':256.*4, 'noverlap':128*4/2}, 'oi_min':.5, 'oi_max':1.5, 'oi_upper':1000., 'oi_fs':256*4, 'keep':['data'], 'compute_performance_name_and_x':lambda x: [x[0],1] } return d fig, axs=ps.get_figure2(n_rows=4, n_cols=3, w=int(72/2.54*11.6*(1+1./2+0.2))*scale, h=int(0.85*72/2.54*11.6*(1+1./2))*scale, # w=k.get('w',500), # h=k.get('h',900), linewidth=1, fontsize=7*scale, title_fontsize=7*scale, gs_builder=gs_builder) models=['M1', 'M2', 'FS', 'GA', 'GF', 'GI', 'GP', 'ST','SN', 'GI_MS', 'GA_MS', 'GF_MS', 'GI_FS', 'GA_FS', 'GF_FS'] # models=[m for m in models if not ( m in exclude)]
def plot_coher(d, labelsy, labelsx=None, title_name='Slow wave'):
    """Plot a heat map of per-connection effects plus a bar chart of row means.

    The heat map (``axs[0]``) has one row per element of ``bar.y[0, :]`` and
    two column groups: the left half, labelled "Coherence", is built from
    ``bar.y[0, :]`` and the right half, labelled "Phase shift", from
    ``bar.y[1, :]``. Each half has one column per key of ``d['bar_obj']``
    (sorted). ``axs[1]`` shows the mean of every heat-map row as horizontal
    bars, with a dashed reference line at 1.

    Args:
        d: mapping with a ``'bar_obj'`` entry; each value must expose a 2-D
            ``y`` array whose rows 0 and 1 are used.
        labelsy: row labels, top row first. Labels known to
            ``nice_labels(version=0)`` are replaced by their display form.
            The caller's sequence is not modified.
        labelsx: optional column labels placed before the ``d['bar_obj']``
            keys; ``None`` (default) means none. Labels known to
            ``nice_labels(version=1)`` are translated. The caller's
            sequence is not modified.
        title_name: bold, vertical title drawn to the right of the plot.

    Returns:
        The figure created by ``ps.get_figure2``.

    Raises:
        ValueError: if ``d['bar_obj']`` is empty.
    """
    from scripts_inhibition.base_effect_conns import nice_labels

    bar_objs = d['bar_obj']
    if not bar_objs:
        raise ValueError("d['bar_obj'] is empty; nothing to plot")

    fig, axs = ps.get_figure2(
        n_rows=9,
        n_cols=12,
        w=23 / 56. * 17.6 * 72 / 2.54,
        h=23 / 56. * 17.6 * 72 / 2.54,
        fontsize=7,
        frame_hight_y=0.5,
        frame_hight_x=0.7,
        title_fontsize=7,
        gs_builder=gs_builder,
        linewidth=1.)
    for ax in axs:
        ax.tick_params(direction='in', length=2, width=0.5)

    # Work on copies so neither the caller's lists nor a shared default
    # are mutated between calls.
    row_names = nice_labels(version=0)
    labelsy = [row_names[label] if label in row_names.keys() else label
               for label in labelsy]

    coherence_cols = []
    phase_cols = []
    labelsx = [] if labelsx is None else list(labelsx)
    for key in sorted(bar_objs.keys()):
        y = bar_objs[key].y
        coherence_cols.append(y[0, :])
        phase_cols.append(y[1, :])
        labelsx.append(key)
    # Rows follow labelsy; columns are [coherence per key..., phase per key...].
    z = numpy.transpose(numpy.array(coherence_cols + phase_cols))

    col_names = nice_labels(version=1)
    labelsx = [col_names[label] if label in col_names.keys() else label
               for label in labelsx]

    _vmin = 0
    _vmax = 4
    # Grid size follows the data instead of a hard-coded 14 x 8.
    maxy, maxx = z.shape
    posy = numpy.linspace(0.5, maxy - 0.5, maxy)
    posx = numpy.linspace(0.5, maxx - 0.5, maxx)

    # Row means, reversed so the first heat-map row (drawn on top) lines up.
    axs[1].barh(posy, numpy.mean(z, axis=1)[::-1],
                align='center', color='0.5', edgecolor='0.5')
    axs[1].plot([1, 1], [0, maxy], 'k', linestyle='--')

    # y runs from maxy down to 0 so row 0 of z is drawn at the top.
    x1, y1 = numpy.meshgrid(numpy.linspace(0, maxx, maxx + 1),
                            numpy.linspace(maxy, 0, maxy + 1))
    im = axs[0].pcolor(x1, y1, z, cmap='jet', vmin=_vmin, vmax=_vmax)
    axs[0].set_yticks(posy)
    axs[0].set_yticklabels(labelsy[::-1])
    axs[0].set_xticks(posx)
    axs[0].set_xticklabels(labelsx * 2, rotation=70, ha='right')
    axs[0].set_ylim([0, maxy])
    axs[0].text(0.25, -0.39, "Coherence", ha='center',
                transform=axs[0].transAxes)
    axs[0].text(0.75, -0.39, "Phase shift", ha='center',
                transform=axs[0].transAxes)
    axs[1].text(0.5, -0.15, "Mean", transform=axs[1].transAxes,
                ha='center', rotation=0)
    axs[1].text(0.5, -0.22, "effect", transform=axs[1].transAxes,
                ha='center', rotation=0)

    font0 = FontProperties()
    font0.set_weight('bold')
    # Positioned in axs[0] coordinates so it sits right of both panels.
    axs[1].text(1.3, 0.5, title_name, va='center',
                transform=axs[0].transAxes, rotation=270,
                fontproperties=font0)
    axs[0].text(-0.58, 0.5, "Connection without dop. effect",
                transform=axs[0].transAxes, rotation=90,
                ha='center', va='center')

    axs[1].my_remove_axis(xaxis=False, yaxis=True)
    axs[1].my_set_no_ticks(xticks=2)
    # Keep the bar rows aligned with the heat-map rows.
    axs[1].set_ylim([0, maxy])
    axs[1].set_xlim([0, 2])

    box = axs[0].get_position()
    axColor = pylab.axes([box.x0 + 0.0 * box.width,
                          box.y0 + box.height + box.height * 0.15,
                          box.width * 0.8,
                          0.02])
    cbar = pylab.colorbar(im, cax=axColor, orientation="horizontal")
    cbar.ax.set_title('Deviation from base model MSE')
    cbar.ax.tick_params(direction='in', length=1, width=0.5)
    from matplotlib import ticker
    cbar.locator = ticker.MaxNLocator(nbins=4)
    cbar.update_ticks()
    return fig
# # # # # # show_heat_map(dd, 'mean_rate_slices', **k) # # show_variability(dd, 'mean_rate_slices', **k) # # show_variability(dd, 'mean_rate_slices', net='Net_01',**k) # show_variability(dd, 'mean_rate_slices', net='Net_03',**k) # show_variability(dd, 'mean_rate_slices', net='Net_05',**k) # show_variability(dd, 'mean_rate_slices', net='Net_06',**k) scale=4 fig, axs=ps.get_figure2(n_rows=4, n_cols=2, w=int(72/2.54*8.9) *scale, h=450*scale, fontsize=7*scale, title_fontsize=7*scale, # grid=[3,2], gs_builder=gs_builder_new) print len(axs) # pylab.show() k={'axs':axs, # 'do_colorbar':False, 'fig':fig, 'model':'SN', 'print_statistics':False, 'resolution':10,
def create_figs(setup, file_name_figs, d, models):
    """Build the firing-rate / rate-difference figure and save it.

    ``axs[0:2]`` get ``show_fr`` (settings from ``setup.plot_fr2()``) and
    ``axs[2:4]`` get ``show_mr_diff`` (settings from
    ``setup.plot_mr_diff()``). The figure is saved through the store loaded
    from *file_name_figs*, as PNG (dpi=400) and as SVG into ``'svg'``.
    """
    sd_figs = Storage_dic.load(file_name_figs)
    figs = []
    fr2_kwargs = setup.plot_fr2()
    mr_diff_kwargs = setup.plot_mr_diff()

    layout = setup.plot_fig_axs().get('fig_and_axes')
    fig, axs = ps.get_figure2(**layout)
    figs.append(fig)
    for ax in axs:
        ax.tick_params(direction='out', length=2, width=0.5, pad=0.01,
                       top=False, right=False)

    # Top panels: firing rates.
    show_fr(d, models, axs[0:2], **fr2_kwargs)

    # Re-read the axes from the figure before styling the lower panels.
    axs = figs[-1].get_axes()
    for ax in axs[2:4]:
        ax.set_ylim([0, 70])
        ax.my_set_no_ticks(xticks=4)

    # Bottom panels: rate differences.
    show_mr_diff(d, models, axs[2:4], **mr_diff_kwargs)
    for ax in axs[2:4]:
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles, labels, loc='upper left')

    annotations = (
        (0, 'Firing rate (Hz)', -0.45, 0., 90),
        (0, r'$MSN_{D1}$', -0.27, 0.5, 90),
        (1, r'$MSN_{D2}$', -0.27, 0.5, 90),
    )
    for idx, text, x_pos, y_pos, angle in annotations:
        axs[idx].text(x_pos, y_pos, text, fontsize=7,
                      transform=axs[idx].transAxes,
                      verticalalignment='center',
                      horizontalalignment='center',
                      rotation=angle)

    for ax in axs[1:4]:
        ax.legend().set_visible(False)

    # Single shared legend anchored outside the first panel.
    axs[0].legend(axs[0].lines[:], fr2_kwargs['labels'][:],
                  bbox_to_anchor=(1.2, 1.8), ncol=1, handletextpad=0.1,
                  frameon=False, columnspacing=0.3, labelspacing=0.2)
    axs[0].my_remove_axis(xaxis=True, yaxis=False, keep_ticks=False)
    axs[2].my_remove_axis(xaxis=True, yaxis=False, keep_ticks=False)

    for idx, ax in enumerate(axs):
        ax.set_ylabel('')
        ax.my_set_no_ticks(yticks=3, xticks=3)
        if idx == 0:
            ax.set_ylim([0, 60])
            ax.set_xlim([700, 1600])
        elif idx == 1:
            ax.set_xticks([700, 1000, 1300])
            ax.set_xticklabels([0, 300, 600])
            ax.set_ylim([0, 50])
            ax.set_xlim([700, 1600])
        elif idx == 2:
            ax.set_title('Difference')
            ax.set_ylim([0, 30])
        elif idx == 3:
            ax.set_ylim([0, 25])

    sd_figs.save_figs(figs, format='png', dpi=400)
    sd_figs.save_figs(figs, format='svg', in_folder='svg')
print name, net if not (net in d[name].keys()): i+=1 continue dd['Net_{:0>2}'.format(i)]=d[name][net] titles.append(name+'_'+net) i+=1 pp(dd) val=int(72/2.54*17.6*(1-17./48)) scale=1 fig, axs=ps.get_figure2(n_rows=11, n_cols=11, w=val*scale, h=300*scale, fontsize=7*scale, title_fontsize=7*scale, gs_builder=gs_builder) k={'axs':axs, 'do_colorbar':False, 'fig':fig, 'models':['SN'], 'print_statistics':False, 'resolution':10, 'titles':['']*5*5, 'type_of_plot':'mean', 'vlim_rate':[-100, 100], 'marker_size':8}
s = 'Net_{:0>2}'.format(i) dd[s] = d[name][net] translation[s] = name + '_' + net # titles.append(name+'_'+net) i += 1 pp(dd) pp(translation) scale = 1 figs = [] fun_call = [show_heat_map, show_variability_several] for iFig in range(2): fig, axs = ps.get_figure2(n_rows=9, n_cols=10, w=72 / 2.54 * 17.6 * scale, h=72 / 2.54 * 17.6 * scale * 1.1, fontsize=7 * scale, title_fontsize=7 * scale, gs_builder=gs_builder) for ax in axs: ax.tick_params( direction='in', length=0, width=0.5, # pad=1, top=False, right=False, left=False, bottom=False, )
def plot_coher(d, labelsy, labelsx=None, title_name='Slow wave'):
    """Plot a heat map of per-connection effects plus a bar chart of row means.

    The heat map (``axs[0]``) has one row per element of ``bar.y[0, :]`` and
    two column groups: the left half, labelled "Coherence", is built from
    ``bar.y[0, :]`` and the right half, labelled "Phase shift", from
    ``bar.y[1, :]``. Each half has one column per key of ``d['bar_obj']``
    (sorted). ``axs[1]`` shows the mean of every heat-map row as horizontal
    bars, with a dashed reference line at 1.

    Args:
        d: mapping with a ``'bar_obj'`` entry; each value must expose a 2-D
            ``y`` array whose rows 0 and 1 are used.
        labelsy: row labels, top row first. Labels known to
            ``nice_labels(version=0)`` are replaced by their display form.
            The caller's sequence is not modified.
        labelsx: optional column labels placed before the ``d['bar_obj']``
            keys; ``None`` (default) means none. Labels known to
            ``nice_labels(version=1)`` are translated. The caller's
            sequence is not modified.
        title_name: bold, vertical title drawn to the right of the plot.

    Returns:
        The figure created by ``ps.get_figure2``.

    Raises:
        ValueError: if ``d['bar_obj']`` is empty.
    """
    from scripts_inhibition.base_effect_conns import nice_labels

    bar_objs = d['bar_obj']
    if not bar_objs:
        raise ValueError("d['bar_obj'] is empty; nothing to plot")

    fig, axs = ps.get_figure2(
        n_rows=9,
        n_cols=12,
        w=23 / 56. * 17.6 * 72 / 2.54,
        h=23 / 56. * 17.6 * 72 / 2.54,
        fontsize=7,
        frame_hight_y=0.5,
        frame_hight_x=0.7,
        title_fontsize=7,
        gs_builder=gs_builder,
        linewidth=1.)
    for ax in axs:
        ax.tick_params(direction='in', length=2, width=0.5)

    # Work on copies so neither the caller's lists nor a shared default
    # are mutated between calls.
    row_names = nice_labels(version=0)
    labelsy = [row_names[label] if label in row_names.keys() else label
               for label in labelsy]

    coherence_cols = []
    phase_cols = []
    labelsx = [] if labelsx is None else list(labelsx)
    for key in sorted(bar_objs.keys()):
        y = bar_objs[key].y
        coherence_cols.append(y[0, :])
        phase_cols.append(y[1, :])
        labelsx.append(key)
    # Rows follow labelsy; columns are [coherence per key..., phase per key...].
    z = numpy.transpose(numpy.array(coherence_cols + phase_cols))

    col_names = nice_labels(version=1)
    labelsx = [col_names[label] if label in col_names.keys() else label
               for label in labelsx]

    _vmin = 0
    _vmax = 4
    # Grid size follows the data instead of a hard-coded 14 x 8.
    maxy, maxx = z.shape
    posy = numpy.linspace(0.5, maxy - 0.5, maxy)
    posx = numpy.linspace(0.5, maxx - 0.5, maxx)

    # Row means, reversed so the first heat-map row (drawn on top) lines up.
    axs[1].barh(posy, numpy.mean(z, axis=1)[::-1],
                align='center', color='0.5', edgecolor='0.5')
    axs[1].plot([1, 1], [0, maxy], 'k', linestyle='--')

    # y runs from maxy down to 0 so row 0 of z is drawn at the top.
    x1, y1 = numpy.meshgrid(numpy.linspace(0, maxx, maxx + 1),
                            numpy.linspace(maxy, 0, maxy + 1))
    im = axs[0].pcolor(x1, y1, z, cmap='jet', vmin=_vmin, vmax=_vmax)
    axs[0].set_yticks(posy)
    axs[0].set_yticklabels(labelsy[::-1])
    axs[0].set_xticks(posx)
    axs[0].set_xticklabels(labelsx * 2, rotation=70, ha='right')
    axs[0].set_ylim([0, maxy])
    axs[0].text(0.25, -0.39, "Coherence", ha='center',
                transform=axs[0].transAxes)
    axs[0].text(0.75, -0.39, "Phase shift", ha='center',
                transform=axs[0].transAxes)
    axs[1].text(0.5, -0.15, "Mean", transform=axs[1].transAxes,
                ha='center', rotation=0)
    axs[1].text(0.5, -0.22, "effect", transform=axs[1].transAxes,
                ha='center', rotation=0)

    font0 = FontProperties()
    font0.set_weight('bold')
    # Positioned in axs[0] coordinates so it sits right of both panels.
    axs[1].text(1.3, 0.5, title_name, va='center',
                transform=axs[0].transAxes, rotation=270,
                fontproperties=font0)
    axs[0].text(-0.58, 0.5, "Connection without dop. effect",
                transform=axs[0].transAxes, rotation=90,
                ha='center', va='center')

    axs[1].my_remove_axis(xaxis=False, yaxis=True)
    axs[1].my_set_no_ticks(xticks=2)
    # Keep the bar rows aligned with the heat-map rows.
    axs[1].set_ylim([0, maxy])
    axs[1].set_xlim([0, 2])

    box = axs[0].get_position()
    axColor = pylab.axes([box.x0 + 0.0 * box.width,
                          box.y0 + box.height + box.height * 0.15,
                          box.width * 0.8,
                          0.02])
    cbar = pylab.colorbar(im, cax=axColor, orientation="horizontal")
    cbar.ax.set_title('Deviation from base model MSE')
    cbar.ax.tick_params(direction='in', length=1, width=0.5)
    from matplotlib import ticker
    cbar.locator = ticker.MaxNLocator(nbins=4)
    cbar.update_ticks()
    return fig
def create_figs(file_name_figs, from_disks, d, models, setup):
    """Build the MSN D1/D2 mean-rate figure and save it as PNG and SVG.

    ``axs[2:4]`` receive ``show_mr`` with ``setup.plot_mr2()`` settings and
    ``axs[0:2]`` receive ``show_mr`` with ``setup.plot_mr()`` settings, both
    for models ``['M1', 'M2']``.

    Args:
        file_name_figs: argument to ``Storage_dic.load``; the returned store
            saves the figures (PNG at dpi=100, SVG into ``'svg'``).
        from_disks: not used by this function.
        d: data passed through to ``show_mr``.
        models: not used by this function (the model list is fixed).
        setup: provides ``plot_mr()``, ``plot_mr2()`` and
            ``plot_mr_general()``; the latter's ``'fig_and_axes'`` entry is
            the keyword set for ``ps.get_figure2``.
    """
    sd_figs = Storage_dic.load(file_name_figs)
    d_plot_mr = setup.plot_mr()
    d_plot_mr2 = setup.plot_mr2()
    figs = []

    fig_and_axes = setup.plot_mr_general().get('fig_and_axes')
    pp(fig_and_axes)
    fig, axs = ps.get_figure2(**fig_and_axes)
    figs.append(fig)
    for ax in axs:
        ax.tick_params(direction='out', length=2, width=0.5, pad=0.01,
                       top=False, right=False)

    show_mr(d, ['M1', 'M2'], axs[2:4], **d_plot_mr2)

    for i, s, c0, c1, rotation in [
            [2, 'Rel. inh. effect', -0.31, 0., 90],
            [0, 'Firing rate (Hz)', -0.45, 0., 90],
            [0, r'$MSN_{D1}$', -0.27, 0.5, 90],
            [1, r'$MSN_{D2}$', -0.27, 0.5, 90],
    ]:
        axs[i].text(c0, c1, s, fontsize=7, transform=axs[i].transAxes,
                    verticalalignment='center',
                    horizontalalignment='center', rotation=rotation)

    # Hide per-panel legends; ax.legend() may return None (older
    # matplotlib) when there is nothing to label.
    for ax in axs[2:4]:
        if not ax.legend():
            continue
        ax.legend().set_visible(False)

    show_mr(d, ['M1', 'M2'], axs[0:2], **d_plot_mr)

    # One shared legend anchored outside the first panel.
    axs[0].legend(axs[0].lines[0:6], d_plot_mr2['labels'][0:6],
                  bbox_to_anchor=(2.4, 2.4), ncol=2, handletextpad=0.1,
                  frameon=False, columnspacing=0.3, labelspacing=0.2)
    for ax in axs[1:2]:
        if ax.legend():
            ax.legend().set_visible(False)

    axs[0].my_remove_axis(xaxis=True, yaxis=False, keep_ticks=False)
    axs[2].my_remove_axis(xaxis=True, yaxis=False, keep_ticks=False)
    for i, ax in enumerate(axs):
        ax.set_ylabel('')
        ax.my_set_no_ticks(yticks=3, xticks=4)
        if i == 0:
            ax.set_yticks([0, 10, 20])
        elif i == 1:
            ax.set_yticks([0, 7, 14])
            ax.set_ylim([0, 20])

    # Drop the first and last line of the relative-effect panel when it
    # holds more than two. Artist.remove() works on every matplotlib
    # version; ax.lines.remove(...) fails once ax.lines is an ArtistList.
    rel_lines = axs[2].lines
    if len(rel_lines) > 2:
        first, last = rel_lines[0], rel_lines[-1]
        first.remove()
        last.remove()

    sd_figs.save_figs(figs, format='png', dpi=100)
    sd_figs.save_figs(figs, format='svg', in_folder='svg')
def show_heat_map(d, attr, **k):
    """Draw one heat map per network comparing ``set_1`` against ``set_0``.

    For every key of *d* (sorted), the digits after the first ``'_'`` select
    the target axes (``'Net_3'`` -> ``axs[3]``). The per-trial
    ``y_raw_data`` of ``d[key]['set_0']['SN'][attr]`` and
    ``d[key]['set_1']['SN'][attr]`` are compared, and the result is reshaped
    to ``resolution x resolution`` and drawn with ``pcolor``. The x grid comes
    from set_0's ``x_set`` and the y grid from set_1's ``x_set``.

    Keyword args (read from **k):
        do_colorbar (True): draw a colour bar when key == 'Net_0'.
        print_statistics (True): write the ``performance`` dict on axs[-1].
        resolution: grid side length used for reshaping.
        titles: per-axes title list, indexed like the axes.
        type_of_plot ('mean'): 'variance' plots std(set_1 - set_0) with
            ``vlim_variance``; 'CV' plots std / |mean| with ``vlim_CV``;
            anything else plots mean(set_1 - set_0) with ``vlim_rate``.
        threshold (14): forwarded to ``set_action_selection_marker``.
        axs, fig: target axes/figure; created via ``ps.get_figure2`` if
            either is missing.
        pos_ax_titles (1.05), fontsize_ax_titles (7): title placement.

    Returns:
        ``(fig, performance)`` where ``performance[key]`` is
        ``round(sum(|z|), 2)`` of the plotted values.
    """
    do_colorbar = k.get('do_colorbar', True)
    # NOTE(review): the model list is fixed; a 'models' entry in k is ignored.
    models = ['SN']
    print_statistics = k.get('print_statistics', True)
    res = k.get('resolution')
    titles = k.get('titles')
    vlim_variance = k.get('vlim_variance')
    vlim_CV = k.get('vlim_CV')
    vlim_rate = k.get('vlim_rate')
    axs = k.get('axs')
    fig = k.get('fig')
    # Explicit checks: ``not axs`` raises ValueError for numpy arrays of axes.
    if fig is None or axs is None or len(axs) == 0:
        fig, axs = ps.get_figure2(n_rows=3,
                                  n_cols=8,
                                  w=780 / (24. / 7.),
                                  h=910 / (24. / 7.),
                                  fontsize=7,
                                  title_fontsize=7,
                                  gs_builder=gs_builder)
        # Stored back so set_action_selection_marker(**k) sees them.
        k['fig'] = fig
        k['axs'] = axs

    type_of_plot = k.get('type_of_plot', 'mean')
    thr = k.get('threshold', 14)

    performance = {}
    for model in models:
        for key in sorted(d.keys()):
            print(key)
            i = int(key.split('_')[1])
            obj0 = d[key]['set_0'][model][attr]
            obj1 = d[key]['set_1'][model][attr]
            diff = obj1.y_raw_data - obj0.y_raw_data
            x, y, z, z_std, d0, d1 = [
                numpy.reshape(arg, [res, res]) for arg in (
                    obj0.x_set,
                    obj1.x_set,
                    numpy.mean(diff, axis=0),
                    numpy.std(diff, axis=0),
                    numpy.mean(obj0.y_raw_data, axis=0),
                    numpy.mean(obj1.y_raw_data, axis=0),
                )
            ]

            if type_of_plot == 'variance':
                z = z_std
                _vmin, _vmax = vlim_variance
            elif type_of_plot == 'CV':
                z = z_std / numpy.abs(z)
                _vmin, _vmax = vlim_CV
            else:
                _vmin, _vmax = vlim_rate

            # Cell edges (x1, y1) for pcolor and cell centres (x2, y2) for markers.
            stepx = (x[0, -1] - x[0, 0]) / res
            stepy = (y[-1, 0] - y[0, 0]) / res
            x1, y1 = numpy.meshgrid(
                numpy.linspace(x[0, 0], x[0, -1], res + 1),
                numpy.linspace(y[0, 0], y[-1, 0], res + 1))
            x2, y2 = numpy.meshgrid(
                numpy.linspace(x[0, 0] + stepx / 2, x[0, -1] - stepx / 2, res),
                numpy.linspace(y[0, 0] + stepy / 2, y[-1, 0] - stepy / 2, res))

            im = axs[i].pcolor(x1, y1, z, cmap='coolwarm',
                               vmin=_vmin, vmax=_vmax)
            for marker in ['d', 's', 'n']:
                set_action_selection_marker(i, d0, d1, x2, y2, thr,
                                            marker=marker, **k)
            performance[key] = numpy.round(numpy.sum(numpy.abs(z)), 2)

            axs[i].set_xlabel('Action 1')
            axs[i].set_ylabel('Action 2')
            axs[i].set_xlim([x[0, 0], x[0, -1]])
            axs[i].set_ylim([y[0, 0], y[-1, 0]])

            if key == 'Net_0' and do_colorbar:
                label = 'Contrast (spike/s)'
                box = axs[0].get_position()
                axColor = pylab.axes([box.x0 + box.width * 2.56,
                                      box.y0 + box.height * 0.1,
                                      0.02,
                                      box.height * 0.8])
                cbar = pylab.colorbar(im, cax=axColor,
                                      orientation="vertical")
                cbar.ax.set_ylabel(label, rotation=270)
                cbar.set_ticks([_vmin, _vmax])

            axs[i].text(0.5, k.get('pos_ax_titles', 1.05), titles[i],
                        horizontalalignment='center',
                        transform=axs[i].transAxes,
                        fontsize=k.get('fontsize_ax_titles', 7))

    if print_statistics:
        ax = axs[-1]
        import pprint
        ax.text(0.1, 0.1, pprint.pformat(performance),
                transform=ax.transAxes, fontsize=7)

    return fig, performance
# Plot firing rates of models 'M1', 'M2' and 'SN' for one network and return the figure.
#
# Args:
#   d: data passed to show_fr (also pretty-printed with pp).
#   d_plot_fr2: keyword dict forwarded to show_fr; its 'fig_and_axes' entry is
#     unpacked into ps.get_figure2 to create the figure and axes.
# Returns:
#   The figure created by ps.get_figure2.
#
# Labels axs[0..2] as 'D1', 'D2' and 'SNr', puts a 'Firing rate (Hz)' label
# on axs[1], adds an 'Action 1'/'Action 2' legend to axs[0] (legends on
# axs[1:3] are hidden), and writes 'Time (ms)' below the panel with index 2.
# NOTE(review): this definition is collapsed onto a single line, so the
# original statement nesting is lost. In particular it is unclear whether
# "ax.set_yticks([0,30])" after the "# if i==0:" comment was live code or
# commented out. Restore the line breaks from version control before editing.
def show_rate_D1_D2_SNR(d, d_plot_fr2): fig, axs=ps.get_figure2(**d_plot_fr2.get('fig_and_axes')) pp(d) show_fr(d, ['M1', 'M2', 'SN'], axs, **d_plot_fr2) for i, s, c0, c1, rotation in [ [1, 'Firing rate (Hz)', -0.45, 0.5, 90], [0, r'$D1$', -0.25, 0.5, 90], [1, r'$D2$', -0.25, 0.5, 90], [2, r'SNr', -0.25, 0.5, 90], ]: axs[i].text(c0, c1, s, fontsize=7, transform=axs[i].transAxes, verticalalignment='center', horizontalalignment='center', rotation=rotation) axs[0].legend(axs[0].lines[0:6], ['Action 1', 'Action 2'], bbox_to_anchor=(1.15, 1.85), ncol=2, handletextpad=0.1, frameon=False, columnspacing=0.3, labelspacing=0.2) for i, ax in enumerate(axs): ax.my_set_no_ticks(xticks=3, yticks=2) if i==2: ax.set_yticks([0,60]) ax.set_xlabel('') ax.text(0.5,-0.65,'Time (ms)', fontsize=7, transform =ax.transAxes, ha='center',va='center') if i==1: ax.set_yticks([0,20]) # if i==0: ax.set_yticks([0,30])# ax.set_ylabel('') axs[0].my_remove_axis(xaxis=True, yaxis=False, keep_ticks=True) axs[1].my_remove_axis(xaxis=True, yaxis=False, keep_ticks=True) for ax in axs: ax.tick_params(direction='out', length=1, width=0.5, top=False, right=False ) #think you have to to this after toget it to work for ax in axs: ax.tick_params( pad=1, ) for i, ax in enumerate(axs[1:3]): if not ax.legend(): continue ax.legend().set_visible(False) return fig
def create_figs(setup, file_name_figs, d, models):
    """Build all result figures for this setup and save them.

    Figures, in order:
      1. ``plot_optimal(setup.res)``;
      2. ``show_heat_map`` for every slice in
         ``setup.get_l_mean_rate_slices()``, first with ``type_of_plot``
         'mean', then 'variance', then 'CV';
      3. ``show_bulk`` for every slice;
      4. rate plots (``show_rate_D1_D2_SNR`` and ``show_rate_all``) for
         ``'Net_1'`` when *d* has more than one entry, else for ``'Net_0'``.
         A failing rate plot is skipped with a warning.

    All figures are saved through the store loaded from *file_name_figs*,
    as PNG (dpi=400) and as SVG into ``'svg'``.
    """
    sd_figs = Storage_dic.load(file_name_figs)
    figs = []
    d_plot_3d = setup.plot_3d()
    d_plot_fr = setup.plot_fr()
    d_plot_fr2 = setup.plot_fr2()
    d_plot_bulk = setup.plot_bulkt()
    l_mean_rate_slices = setup.get_l_mean_rate_slices()

    figs.append(plot_optimal(setup.res))

    # NOTE(review): this figure is never used or saved. It is kept because
    # ps.get_figure2 may configure global plot settings as a side effect;
    # confirm and remove if it does not.
    ps.get_figure2(n_rows=3,
                   n_cols=8,
                   w=780 / (24. / 7.),
                   h=910 / (24. / 7.),
                   fontsize=7,
                   title_fontsize=7,
                   gs_builder=gs_builder)
    pp(d)

    for type_of_plot in ('mean', 'variance', 'CV'):
        for name in l_mean_rate_slices:
            d_plot_3d['type_of_plot'] = type_of_plot
            fig, _ = show_heat_map(d, name, **d_plot_3d)
            figs.append(fig)

    for name in l_mean_rate_slices:
        figs.append(show_bulk(d, models, name, **d_plot_bulk))

    # With several networks only Net_1 is plotted; with exactly one, Net_0.
    n_nets = len(d.keys())
    net_ids = [1] if n_nets > 1 else list(range(n_nets))
    for i in net_ids:
        net_key = 'Net_' + str(i)
        # Best effort: a failing plot must not stop the figures from being
        # saved, but Ctrl-C/SystemExit are no longer swallowed and the
        # cause is reported.
        try:
            figs.append(show_rate_D1_D2_SNR(d[net_key], d_plot_fr2))
        except Exception as err:
            warnings.warn('failed rate plot: {0!r}'.format(err))
        try:
            figs.append(show_rate_all(d[net_key], d_plot_fr))
        except Exception as err:
            warnings.warn('failed rate plot: {0!r}'.format(err))

    sd_figs.save_figs(figs, format='png', dpi=400)
    sd_figs.save_figs(figs, format='svg', in_folder='svg')
'add_midpoint':False, 'psd':{'NFFT':128*4, 'fs':256.*4, 'noverlap':128*4/2}, 'oi_min':.5, 'oi_max':1.5, 'oi_upper':1000., 'oi_fs':256*4, 'keep':['data']} return d fig, axs=ps.get_figure2(n_rows=4, n_cols=3, w=int(72/2.54*11.6*(1+1./2+0.2))*scale, h=int(0.85*72/2.54*11.6*(1+1./2))*scale, # w=k.get('w',500), # h=k.get('h',900), linewidth=1, fontsize=7*scale, title_fontsize=7*scale, gs_builder=gs_builder) models=['M1', 'M2', 'FS', 'GA', 'GI', 'GP', 'ST','SN', 'GI_ST', 'GP_GP', 'GA_GA', 'GI_GA', 'GI_GI'] # models=[m for m in models if not ( m in exclude)] nets=['Net_0', 'Net_1'] attrs=[ 'firing_rate', 'mean_coherence', 'phases_diff_with_cohere',
dd['Net_{:0>2}'.format(i)] = d[name][net] # titles.append(name+'_'+net) i += 1 # for name, nets in builder: # for net in nets: # dd['Net_{}'.format(i)]=d[name][net] # i+=1 pp(dd) print len(dd['Net_00']['set_1']['SN']['mean_rate_slices'].y) figs = [] fig, axs = ps.get_figure2(n_rows=19, n_cols=16, w=int(72 / 2.54 * 17.6 * (17. / 48)), h=300, fontsize=7, title_fontsize=7, gs_builder=gs_builder) figs.append(fig) k = { 'axs': axs, 'do_colorbar': False, 'fig': fig, 'models': ['SN'], 'print_statistics': False, 'resolution':
dd['Net_{:0>2}'.format(i)]=d[name][net] # titles.append(name+'_'+net) i+=1 # for name, nets in builder: # for net in nets: # dd['Net_{}'.format(i)]=d[name][net] # i+=1 pp(dd) print len(dd['Net_00']['set_1']['SN']['mean_rate_slices'].y) figs=[] fig, axs=ps.get_figure2(n_rows=19, n_cols=16, w=int(72/2.54*17.6*(17./48)), h=300, fontsize=7, title_fontsize=7, gs_builder=gs_builder) figs.append(fig) k={'axs':axs, 'do_colorbar':False, 'fig':fig, 'models':['SN'], 'print_statistics':False, 'resolution':7, 'threshold':thr, 'titles':['Only D1', 'D1 & D2', r'No MSN$\to$MSN', r'No FSN$\to$MSN',
n=10 z=numpy.random.random((n,n)) z[1,:]+=1 z[8,:]+=2 _vmin=0 _vmax=4 stepx=1 stepy=1 startx=0 starty=0 stopy=10 stopx=10 res=10 fig, axs=ps.get_figure2(n_rows=5, n_cols=6, w=700, h=500, fontsize=24, frame_hight_y=0.5, frame_hight_x=0.7, title_fontsize=20, gs_builder=gs_builder) pos=numpy.linspace(0.5,9.5,10) axs[1].barh(pos,numpy.mean(z,axis=1)[::-1], align='center') # ax=pylab.subplot(111) nets=['Net_'+str(i) for i in range(10)] x1,y1=numpy.meshgrid(numpy.linspace(startx, stopx, res+1), numpy.linspace(stopy, starty, res+1)) # x2,y2=numpy.meshgrid(numpy.linspace(startx+stepx/2, # stopx-stepx/2, res), # numpy.linspace(stopy+stepy/2, # starty-stepy/2, res))