Ejemplo n.º 1
0
def plot_multi_scale_output_a(fig):
    '''
    Draw the four-column multi-scale activity overview into `fig`.

    Columns, left to right:

    A: spike raster of all populations in the time window T,
    B: population firing rates for every population in networkSim.X
       except the last one,
    C: mean somatic excitatory ('E'), inhibitory ('I') and summed input
       currents for each population in params.Y,
    D: mean somatic membrane voltage for each population in params.Y.

    The mean input currents and voltages are cached as pickle files in
    <params.savefolder>/data_analysis; they are computed (and written) only
    when the corresponding file does not exist yet.

    Relies on the module-level objects `params` and `analysis_params`.

    Parameters
    ----------
    fig : matplotlib figure
        Figure to which all axes are added with `fig.add_axes`.
    '''
    data_analysis_path = os.path.join(params.savefolder, 'data_analysis')

    # Mean somatic input currents: compute and cache, or load from cache.
    # NOTE: pickle.load can execute arbitrary code; only load cache files
    # written by this analysis pipeline.
    fname = os.path.join(data_analysis_path, 'meanInpCurrents.pickle')
    if not os.path.isfile(fname):
        meanInpCurrents = getMeanInpCurrents(
            params, params.n_rec_input_spikes,
            os.path.join(params.spike_output_path, 'population_input_spikes'))
        with open(fname, 'wb') as f:
            pickle.dump(meanInpCurrents, f)
    else:
        with open(fname, 'rb') as f:
            meanInpCurrents = pickle.load(f)

    # Mean somatic membrane voltages: same caching scheme.
    fname = os.path.join(data_analysis_path, 'meanVoltages.pickle')
    if not os.path.isfile(fname):
        meanVoltages = getMeanVoltages(
            params, params.n_rec_voltage,
            os.path.join(params.spike_output_path, 'voltages'))
        with open(fname, 'wb') as f:
            pickle.dump(meanVoltages, f)
    else:
        with open(fname, 'rb') as f:
            meanVoltages = pickle.load(f)

    # load spikes as database
    networkSim = CachedNetwork(**params.networkSimParams)
    if analysis_params.bw:
        networkSim.colors = phlp.get_colors(len(networkSim.X))

    show_ax_labels = True
    show_insets = False

    T = [800, 1000]        # time window shown in every column (ms)
    T_inset = [900, 920]   # time window of the optional raster inset (ms)

    sep = 0.025 / 2  # vertical gap between stacked per-population axes

    # figure-relative layout of the four columns
    left = 0.075
    bottom = 0.55
    top = 0.975
    right = 0.95
    axwidth = 0.16
    numcols = 4
    insetwidth = axwidth / 2
    insetheight = 0.5

    lefts = np.linspace(left, right - axwidth, numcols)

    def stacked_layout(nrows):
        '''Return (height, bottoms) for `nrows` axes stacked top to bottom.'''
        thickn = (top - bottom) / nrows - sep
        # reversed so that index 0 is the top-most axes
        return thickn, np.linspace(bottom, top - thickn, nrows)[::-1]

    ############################################################################
    # A part, plot spike rasters
    ############################################################################
    ax1 = fig.add_axes([lefts[0], bottom, axwidth, top - bottom])
    if show_ax_labels:
        phlp.annotate_subplot(ax1, ncols=4, nrows=1.02, letter='A')
    ax1.set_title('network activity')
    plt.locator_params(nbins=4)

    x, y = networkSim.get_xy(T, fraction=1)
    networkSim.plot_raster(ax1, T, x, y, markersize=0.2, marker='_', alpha=1.,
                           legend=False, pop_names=True, rasterized=False)
    phlp.remove_axis_junk(ax1)
    ax1.set_xlabel(r'$t$ (ms)', labelpad=0.1)
    ax1.set_ylabel('population', labelpad=0.1)

    # optional zoomed raster inset in the upper right corner of column A
    if show_insets:
        ax2 = fig.add_axes([lefts[0] + axwidth - insetwidth, top - insetheight,
                            insetwidth, insetheight])
        plt.locator_params(nbins=4)
        x, y = networkSim.get_xy(T_inset, fraction=0.4)
        networkSim.plot_raster(ax2, T_inset, x, y, markersize=0.25, alpha=1.,
                               legend=False)
        phlp.remove_axis_junk(ax2)
        ax2.set_xticks(T_inset)
        ax2.set_yticks([])
        ax2.set_yticklabels([])
        ax2.set_ylabel('')
        ax2.set_xlabel('')

    ############################################################################
    # B part, plot firing rates
    ############################################################################
    nrows = len(networkSim.X) - 1
    thickn, bottoms = stacked_layout(nrows)

    # re-fetch spikes in T (x, y may have been replaced by the inset data)
    x, y = networkSim.get_xy(T, fraction=1)

    # dummy ax to put label in correct location
    ax_ = fig.add_axes([lefts[1], bottom, axwidth, top - bottom])
    ax_.axis('off')
    if show_ax_labels:
        phlp.annotate_subplot(ax_, ncols=4, nrows=1, letter='B')

    for i, X in enumerate(networkSim.X[:-1]):
        ax3 = fig.add_axes([lefts[1], bottoms[i], axwidth, thickn])
        plt.locator_params(nbins=4)
        phlp.remove_axis_junk(ax3)
        networkSim.plot_f_rate(ax3, X, i, T, x, y, yscale='linear',
                               plottype='fill_between', show_label=False,
                               rasterized=False)
        ax3.yaxis.set_major_locator(plt.MaxNLocator(3))
        if i != nrows - 1:
            ax3.set_xticklabels([])

        if i == 3:
            ax3.set_ylabel(r'(s$^{-1}$)', labelpad=0.1)

        if i == 0:
            ax3.set_title(r'firing rates ')

        ax3.text(0, 1, X, horizontalalignment='left',
                 verticalalignment='bottom', transform=ax3.transAxes)

    # the last (bottom-most) rate axes carries the time axis
    for loc, spine in ax3.spines.items():
        if loc in ['right', 'top']:
            spine.set_color('none')
    ax3.xaxis.set_ticks_position('bottom')
    ax3.yaxis.set_ticks_position('left')
    ax3.set_xlabel(r'$t$ (ms)', labelpad=0.1)

    ############################################################################
    # C part, plot somatic synapse input currents population resolved
    ############################################################################
    nrows = len(meanInpCurrents)
    thickn, bottoms = stacked_layout(nrows)

    ax_ = fig.add_axes([lefts[2], bottom, axwidth, top - bottom])
    ax_.axis('off')
    if show_ax_labels:
        phlp.annotate_subplot(ax_, ncols=4, nrows=1, letter='C')

    color_E = 'k' if analysis_params.bw else analysis_params.colorE
    color_I = 'gray' if analysis_params.bw else analysis_params.colorI

    for i, Y in enumerate(params.Y):
        value = meanInpCurrents[Y]

        tvec = value['tvec']
        inds = (tvec <= T[1]) & (tvec >= T[0])
        # every 10th time point, paired with helpers.decimate(..., 10)
        t_dec = tvec[inds][::10]
        ax3 = fig.add_axes([lefts[2], bottoms[i], axwidth, thickn])
        plt.locator_params(nbins=4)

        # only the top axes gets legend labels (its legend is drawn below)
        first = i == 0
        ax3.plot(t_dec, helpers.decimate(value['E'][inds], 10), color_E,
                 rasterized=False, label='exc.' if first else None)
        ax3.plot(t_dec, helpers.decimate(value['I'][inds], 10), color_I,
                 rasterized=False, label='inh.' if first else None)
        ax3.plot(t_dec,
                 helpers.decimate(value['E'][inds] + value['I'][inds], 10),
                 'k', lw=1, rasterized=False, label='sum' if first else None)
        phlp.remove_axis_junk(ax3)

        # ticks at both tight y-limits and at zero, labelled with the rounded
        # minimum inhibitory and maximum excitatory current in the window
        ax3.axis(ax3.axis('tight'))
        ax3.set_yticks([ax3.axis()[2], 0, ax3.axis()[3]])
        ax3.set_yticklabels([np.round((value['I'][inds]).min(), decimals=1),
                             0,
                             np.round((value['E'][inds]).max(), decimals=1)])

        ax3.text(0, 1, Y, horizontalalignment='left',
                 verticalalignment='bottom', transform=ax3.transAxes)

        if i == nrows - 1:
            ax3.set_xlabel('$t$ (ms)', labelpad=0.1)
        else:
            ax3.set_xticklabels([])

        if i == 3:
            ax3.set_ylabel(r'(nA)', labelpad=0.1)

        if i == 0:
            ax3.set_title('input currents')
            ax3.legend(loc=1, prop={'size': 4})
        phlp.remove_axis_junk(ax3)
        ax3.set_xlim(T)

    ############################################################################
    # D part, plot membrane voltage population resolved
    ############################################################################
    nrows = len(meanVoltages)
    thickn, bottoms = stacked_layout(nrows)

    colors = phlp.get_colors(len(params.Y))

    ax_ = fig.add_axes([lefts[3], bottom, axwidth, top - bottom])
    ax_.axis('off')
    if show_ax_labels:
        phlp.annotate_subplot(ax_, ncols=4, nrows=1, letter='D')

    for i, Y in enumerate(params.Y):
        value = meanVoltages[Y]

        tvec = value['tvec']
        inds = (tvec <= T[1]) & (tvec >= T[0])

        ax4 = fig.add_axes([lefts[3], bottoms[i], axwidth, thickn])
        ax4.plot(tvec[inds][::10], helpers.decimate(value['data'][inds], 10),
                 color=colors[i], zorder=0, rasterized=False)

        phlp.remove_axis_junk(ax4)

        plt.locator_params(nbins=4)

        ax4.axis(ax4.axis('tight'))
        ax4.yaxis.set_major_locator(plt.MaxNLocator(3))

        ax4.text(0, 1, Y, horizontalalignment='left',
                 verticalalignment='bottom', transform=ax4.transAxes)

        if i == nrows - 1:
            ax4.set_xlabel('$t$ (ms)', labelpad=0.1)
        else:
            ax4.set_xticklabels([])

        if i == 3:
            ax4.set_ylabel(r'(mV)', labelpad=0.1)

        if i == 0:
            ax4.set_title('voltages')

        # align the time axis with columns B and C
        ax4.set_xlim(T)
Ejemplo n.º 2
0
def plot_multi_scale_output_a(fig):
    '''
    Draw the four-column multi-scale activity overview into `fig`.

    Columns, left to right:

    A: spike raster of all populations in the time window T,
    B: population firing rates for every population in networkSim.X
       except the last one,
    C: mean somatic excitatory ('E'), inhibitory ('I') and summed input
       currents for each population in params.Y,
    D: mean somatic membrane voltage for each population in params.Y.

    The mean input currents and voltages are cached as pickle files in
    <params.savefolder>/data_analysis; they are computed (and written) only
    when the corresponding file does not exist yet.

    Relies on the module-level objects `params` and `analysis_params`.

    Parameters
    ----------
    fig : matplotlib figure
        Figure to which all axes are added with `fig.add_axes`.
    '''
    data_analysis_path = os.path.join(params.savefolder, 'data_analysis')

    # Mean somatic input currents: compute and cache, or load from cache.
    # NOTE: pickle.load can execute arbitrary code; only load cache files
    # written by this analysis pipeline.
    fname = os.path.join(data_analysis_path, 'meanInpCurrents.pickle')
    if not os.path.isfile(fname):
        meanInpCurrents = getMeanInpCurrents(
            params, params.n_rec_input_spikes,
            os.path.join(params.spike_output_path, 'population_input_spikes'))
        with open(fname, 'wb') as f:
            pickle.dump(meanInpCurrents, f)
    else:
        with open(fname, 'rb') as f:
            meanInpCurrents = pickle.load(f)

    # Mean somatic membrane voltages: same caching scheme.
    fname = os.path.join(data_analysis_path, 'meanVoltages.pickle')
    if not os.path.isfile(fname):
        meanVoltages = getMeanVoltages(
            params, params.n_rec_voltage,
            os.path.join(params.spike_output_path, 'voltages'))
        with open(fname, 'wb') as f:
            pickle.dump(meanVoltages, f)
    else:
        with open(fname, 'rb') as f:
            meanVoltages = pickle.load(f)

    # load spikes as database
    networkSim = CachedNetwork(**params.networkSimParams)
    if analysis_params.bw:
        networkSim.colors = phlp.get_colors(len(networkSim.X))

    show_ax_labels = True
    show_insets = False

    T = [800, 1000]        # time window shown in every column (ms)
    T_inset = [900, 920]   # time window of the optional raster inset (ms)

    sep = 0.025 / 2  # vertical gap between stacked per-population axes

    # figure-relative layout of the four columns
    left = 0.075
    bottom = 0.55
    top = 0.975
    right = 0.95
    axwidth = 0.16
    numcols = 4
    insetwidth = axwidth / 2
    insetheight = 0.5

    lefts = np.linspace(left, right - axwidth, numcols)

    def stacked_layout(nrows):
        '''Return (height, bottoms) for `nrows` axes stacked top to bottom.'''
        thickn = (top - bottom) / nrows - sep
        # reversed so that index 0 is the top-most axes
        return thickn, np.linspace(bottom, top - thickn, nrows)[::-1]

    ############################################################################
    # A part, plot spike rasters
    ############################################################################
    ax1 = fig.add_axes([lefts[0], bottom, axwidth, top - bottom])
    if show_ax_labels:
        phlp.annotate_subplot(ax1, ncols=4, nrows=1.02, letter='A')
    ax1.set_title('network activity')
    plt.locator_params(nbins=4)

    x, y = networkSim.get_xy(T, fraction=1)
    networkSim.plot_raster(ax1, T, x, y, markersize=0.2, marker='_', alpha=1.,
                           legend=False, pop_names=True, rasterized=False)
    phlp.remove_axis_junk(ax1)
    ax1.set_xlabel(r'$t$ (ms)', labelpad=0.1)
    ax1.set_ylabel('population', labelpad=0.1)

    # optional zoomed raster inset in the upper right corner of column A
    if show_insets:
        ax2 = fig.add_axes([lefts[0] + axwidth - insetwidth, top - insetheight,
                            insetwidth, insetheight])
        plt.locator_params(nbins=4)
        x, y = networkSim.get_xy(T_inset, fraction=0.4)
        networkSim.plot_raster(ax2, T_inset, x, y, markersize=0.25, alpha=1.,
                               legend=False)
        phlp.remove_axis_junk(ax2)
        ax2.set_xticks(T_inset)
        ax2.set_yticks([])
        ax2.set_yticklabels([])
        ax2.set_ylabel('')
        ax2.set_xlabel('')

    ############################################################################
    # B part, plot firing rates
    ############################################################################
    nrows = len(networkSim.X) - 1
    thickn, bottoms = stacked_layout(nrows)

    # re-fetch spikes in T (x, y may have been replaced by the inset data)
    x, y = networkSim.get_xy(T, fraction=1)

    # dummy ax to put label in correct location
    ax_ = fig.add_axes([lefts[1], bottom, axwidth, top - bottom])
    ax_.axis('off')
    if show_ax_labels:
        phlp.annotate_subplot(ax_, ncols=4, nrows=1, letter='B')

    for i, X in enumerate(networkSim.X[:-1]):
        ax3 = fig.add_axes([lefts[1], bottoms[i], axwidth, thickn])
        plt.locator_params(nbins=4)
        phlp.remove_axis_junk(ax3)
        networkSim.plot_f_rate(ax3, X, i, T, x, y, yscale='linear',
                               plottype='fill_between', show_label=False,
                               rasterized=False)
        ax3.yaxis.set_major_locator(plt.MaxNLocator(3))
        if i != nrows - 1:
            ax3.set_xticklabels([])

        if i == 3:
            ax3.set_ylabel(r'(s$^{-1}$)', labelpad=0.1)

        if i == 0:
            ax3.set_title(r'firing rates ')

        ax3.text(0, 1, X, horizontalalignment='left',
                 verticalalignment='bottom', transform=ax3.transAxes)

    # the last (bottom-most) rate axes carries the time axis
    for loc, spine in ax3.spines.items():
        if loc in ['right', 'top']:
            spine.set_color('none')
    ax3.xaxis.set_ticks_position('bottom')
    ax3.yaxis.set_ticks_position('left')
    ax3.set_xlabel(r'$t$ (ms)', labelpad=0.1)

    ############################################################################
    # C part, plot somatic synapse input currents population resolved
    ############################################################################
    nrows = len(meanInpCurrents)
    thickn, bottoms = stacked_layout(nrows)

    ax_ = fig.add_axes([lefts[2], bottom, axwidth, top - bottom])
    ax_.axis('off')
    if show_ax_labels:
        phlp.annotate_subplot(ax_, ncols=4, nrows=1, letter='C')

    color_E = 'k' if analysis_params.bw else analysis_params.colorE
    color_I = 'gray' if analysis_params.bw else analysis_params.colorI

    for i, Y in enumerate(params.Y):
        value = meanInpCurrents[Y]

        tvec = value['tvec']
        inds = (tvec <= T[1]) & (tvec >= T[0])
        # every 10th time point, paired with helpers.decimate(..., 10)
        t_dec = tvec[inds][::10]
        ax3 = fig.add_axes([lefts[2], bottoms[i], axwidth, thickn])
        plt.locator_params(nbins=4)

        # only the top axes gets legend labels (its legend is drawn below)
        first = i == 0
        ax3.plot(t_dec, helpers.decimate(value['E'][inds], 10), color_E,
                 rasterized=False, label='exc.' if first else None)
        ax3.plot(t_dec, helpers.decimate(value['I'][inds], 10), color_I,
                 rasterized=False, label='inh.' if first else None)
        ax3.plot(t_dec,
                 helpers.decimate(value['E'][inds] + value['I'][inds], 10),
                 'k', lw=1, rasterized=False, label='sum' if first else None)
        phlp.remove_axis_junk(ax3)

        # ticks at both tight y-limits and at zero, labelled with the rounded
        # minimum inhibitory and maximum excitatory current in the window
        ax3.axis(ax3.axis('tight'))
        ax3.set_yticks([ax3.axis()[2], 0, ax3.axis()[3]])
        ax3.set_yticklabels([np.round((value['I'][inds]).min(), decimals=1),
                             0,
                             np.round((value['E'][inds]).max(), decimals=1)])

        ax3.text(0, 1, Y, horizontalalignment='left',
                 verticalalignment='bottom', transform=ax3.transAxes)

        if i == nrows - 1:
            ax3.set_xlabel('$t$ (ms)', labelpad=0.1)
        else:
            ax3.set_xticklabels([])

        if i == 3:
            ax3.set_ylabel(r'(nA)', labelpad=0.1)

        if i == 0:
            ax3.set_title('input currents')
            ax3.legend(loc=1, prop={'size': 4})
        phlp.remove_axis_junk(ax3)
        ax3.set_xlim(T)

    ############################################################################
    # D part, plot membrane voltage population resolved
    ############################################################################
    nrows = len(meanVoltages)
    thickn, bottoms = stacked_layout(nrows)

    colors = phlp.get_colors(len(params.Y))

    ax_ = fig.add_axes([lefts[3], bottom, axwidth, top - bottom])
    ax_.axis('off')
    if show_ax_labels:
        phlp.annotate_subplot(ax_, ncols=4, nrows=1, letter='D')

    for i, Y in enumerate(params.Y):
        value = meanVoltages[Y]

        tvec = value['tvec']
        inds = (tvec <= T[1]) & (tvec >= T[0])

        ax4 = fig.add_axes([lefts[3], bottoms[i], axwidth, thickn])
        ax4.plot(tvec[inds][::10], helpers.decimate(value['data'][inds], 10),
                 color=colors[i], zorder=0, rasterized=False)

        phlp.remove_axis_junk(ax4)

        plt.locator_params(nbins=4)

        ax4.axis(ax4.axis('tight'))
        ax4.yaxis.set_major_locator(plt.MaxNLocator(3))

        ax4.text(0, 1, Y, horizontalalignment='left',
                 verticalalignment='bottom', transform=ax4.transAxes)

        if i == nrows - 1:
            ax4.set_xlabel('$t$ (ms)', labelpad=0.1)
        else:
            ax4.set_xticklabels([])

        if i == 3:
            ax4.set_ylabel(r'(mV)', labelpad=0.1)

        if i == 0:
            ax4.set_title('voltages')

        # align the time axis with columns B and C
        ax4.set_xlim(T)
Ejemplo n.º 3
0
# Script entry point.
# NOTE(review): this example is truncated -- the fig.subplots_adjust(...)
# call at the end is never closed and the rest of the script is missing,
# so the snippet is not runnable as shown.
if __name__ == '__main__':

    # Model parameters and analysis/plotting parameters; the figure style is
    # set to the PLOS two-column format with ratio=0.5.
    params = multicompartment_params()
    ana_params = analysis_params.params()
    ana_params.set_PLOS_2column_fig_style(ratio=0.5)

    # Output locations are derived from params.savefolder; the spike output
    # path is also forwarded to the CachedNetwork keyword arguments.
    params.figures_path = os.path.join(params.savefolder, 'figures')
    params.spike_output_path = os.path.join(params.savefolder,
                                            'processed_nest_output')
    params.networkSimParams['spike_output_path'] = params.spike_output_path

    #load spike as database
    networkSim = CachedNetwork(**params.networkSimParams)
    if analysis_params.bw:
        # presumably black-and-white mode: override the population colours
        networkSim.colors = phlp.get_colors(len(networkSim.X))

    # NOTE(review): transient, T, show_ax_labels, show_images and gs are not
    # used in the visible (truncated) part of this script.
    transient = 200
    T = [890, 920]

    show_ax_labels = True
    show_images = True
    # show_images = False if analysis_params.bw else True

    gs = gridspec.GridSpec(9, 4)

    fig = plt.figure()
    fig.subplots_adjust(left=0.06,
                        right=0.94,
                        bottom=0.075,
                        top=0.925,
Ejemplo n.º 4
0
def fig_network_input_structure(fig,
                                params,
                                bottom=0.1,
                                top=0.9,
                                transient=200,
                                T=[800, 1000],
                                Df=0.,
                                mlab=True,
                                NFFT=256,
                                srate=1000,
                                window=plt.mlab.window_hanning,
                                noverlap=256 * 3 // 4,
                                letters='abcde',
                                flim=(4, 400),
                                show_titles=True,
                                show_xlabels=True,
                                show_CSD=False):
    '''
    This figure is the top part for plotting a comparison between the PD-model
    and the modified-PD model.

    Draws five columns into `fig` on a 5x5 GridSpec spanning [bottom, top]
    vertically (figure coordinates):

    a: spike raster of all populations in the window T,
    b: power spectra of each population's mean-subtracted spike-count
       histogram (unit-width bins from `transient` to the simulation end);
       consecutive population pairs (e.g. 'L23E/I') share one axes,
    c: summed LFP traces from LFPsum.h5, optionally over a CSD colour plot
       from CSDsum.h5,
    d: LFP spectra for channels 1, 4, 8, 12 and 14,
    e: LFP power colour plot (plot_signal_power_colorplot).

    Parameters
    ----------
    fig : matplotlib figure to draw into.
    params : parameter object; params.networkSimParams, params.Y and
        params.savefolder (location of LFPsum.h5 / CSDsum.h5) are used.
    bottom, top : float, vertical extent of the GridSpec.
    transient : start time (ms) of the spike-rate and LFP spectral analysis.
    T : [start, stop] time window (ms) of the raster, LFP and CSD plots.
    Df : passed to hlp.powerspec, calc_signal_power and
        plot_signal_power_colorplot (presumably a frequency resolution /
        smoothing width -- confirm).
    mlab : if True compute spike-rate spectra with plt.mlab.psd, otherwise
        with hlp.powerspec; also forwarded to the LFP spectrum helpers.
    NFFT, srate, window, noverlap : Welch settings for plt.mlab.psd
        (srate is passed as Fs). noverlap must be an integer.
    letters : panel labels, one character per column.
    flim : (fmin, fmax) x-limits of all spectra.
    show_titles, show_xlabels : toggle column titles / x-axis labels.
    show_CSD : overlay the CSD colour plot (and its colorbar) in column c.

    Returns
    -------
    fig : the same figure, with the axes added.
    '''
    # load spikes as database
    networkSim = CachedNetwork(**params.networkSimParams)
    if analysis_params.bw:
        networkSim.colors = phlp.get_colors(len(networkSim.X))

    # use gridspec to get nicely aligned subplots throughout the panel
    gs1 = gridspec.GridSpec(5, 5, bottom=bottom, top=top)

    ############################################################################
    # A part, full dot display
    ############################################################################
    ax0 = fig.add_subplot(gs1[:, 0])
    phlp.remove_axis_junk(ax0)
    phlp.annotate_subplot(ax0, ncols=5, nrows=1, letter=letters[0],
                          linear_offset=0.065)

    x, y = networkSim.get_xy(T, fraction=1)
    networkSim.plot_raster(ax0, T, x, y, markersize=0.2, marker='_', alpha=1.,
                           legend=False, pop_names=True, rasterized=False)
    ax0.set_ylabel('population', labelpad=0.)
    # NOTE(review): tick positions assume the default T=[800, 1000]
    ax0.set_xticks([800, 900, 1000])

    if show_titles:
        ax0.set_title('spiking activity', va='center')
    if show_xlabels:
        ax0.set_xlabel(r'$t$ (ms)', labelpad=0.)
    else:
        ax0.set_xlabel('')

    ############################################################################
    # B part, firing rate spectra
    ############################################################################
    # Spikes of the whole simulation after the transient: x holds the spike
    # times, y the cell ids (per population).
    T_all = [transient, networkSim.simtime]
    bins = np.arange(transient, networkSim.simtime + 1)

    x, y = networkSim.get_xy(T_all, fraction=1)

    # create invisible axes to position labels correctly
    ax_ = fig.add_subplot(gs1[:, 1])
    phlp.annotate_subplot(ax_, ncols=5, nrows=1, letter=letters[1],
                          linear_offset=0.065)
    if show_titles:
        ax_.set_title('firing rate PSD', va='center')
    ax_.axis('off')

    colors = phlp.get_colors(len(params.Y)) + ['k']

    row = 0            # gridspec row of the current population pair
    label_set = False  # y-label is put on the first populated axes only

    layer_titles = ['L23E/I', 'L4E/I', 'L5E/I', 'L6E/I', 'TC']

    # whether the 'TC' population produced any spikes in T_all
    TC = x['TC'].size > 0

    for i, X in enumerate(networkSim.X):

        # even indices open a new axes shared with the following population
        if i % 2 == 0:
            ax1 = fig.add_subplot(gs1[row, 1])
            phlp.remove_axis_junk(ax1)

            if x[X].size > 0:
                ax1.text(0.05, 0.85, layer_titles[row],
                         horizontalalignment='left',
                         verticalalignment='bottom',
                         transform=ax1.transAxes)

        # firing rate histogram with the mean removed
        hist = np.histogram(x[X], bins=bins)[0].astype(float)
        hist -= hist.mean()

        if mlab:
            Pxx, freqs = plt.mlab.psd(hist, NFFT=NFFT, Fs=srate,
                                      noverlap=noverlap, window=window)
        else:
            freqs, Pxx = hlp.powerspec([hist], tbin=1., Df=Df,
                                       pointProcess=False)
            mask = np.where(freqs >= 0.)
            freqs = freqs[mask]
            Pxx = Pxx.flatten()[mask]
            Pxx = Pxx / (T_all[1] - T_all[0])**2

        if x[X].size > 0:
            # skip the first frequency bin (f = 0 is not shown on log axes)
            ax1.loglog(freqs[1:], Pxx[1:], label=X, color=colors[i],
                       clip_on=True)
            ax1.axis(ax1.axis('tight'))
            ax1.set_ylim([5E-4, 5E2])
            ax1.set_yticks([1E-3, 1E-1, 1E1])
            if not label_set:
                ax1.set_ylabel(r'(s$^{-2}$/Hz)', labelpad=0.)
                label_set = True
            if i > 1:
                ax1.set_yticklabels([])
            # x-label only on the bottom-most drawn row: the pair at i >= 6
            # when TC is silent, otherwise the TC row
            if show_xlabels and ((i >= 6 and not TC) or (X == 'TC' and TC)):
                ax1.set_xlabel('$f$ (Hz)', labelpad=0.)
            if (TC and i < 8) or (not TC and i < 6):
                ax1.set_xticklabels([])
        else:
            ax1.axis('off')

        ax1.set_xlim(flim)

        if i % 2 == 0:
            row += 1

        ax1.yaxis.set_minor_locator(plt.NullLocator())

    ############################################################################
    # c part, LFP traces and CSD color plots
    ############################################################################
    ax2 = fig.add_subplot(gs1[:, 2])

    phlp.annotate_subplot(ax2, ncols=5, nrows=1, letter=letters[2],
                          linear_offset=0.065)

    phlp.remove_axis_junk(ax2)
    plot_signal_sum(ax2, params,
                    fname=os.path.join(params.savefolder, 'LFPsum.h5'),
                    unit='mV', T=T, ylim=[-1600, 40], rasterized=False)

    # CSD background colorplot, over the same time window as the LFP traces
    if show_CSD:
        im = plot_signal_sum_colorplot(ax2, params,
                                       os.path.join(params.savefolder,
                                                    'CSDsum.h5'),
                                       unit=r'($\mu$Amm$^{-3}$)',
                                       T=T,
                                       colorbar=False,
                                       ylim=[-1600, 40],
                                       fancy=False,
                                       cmap=plt.cm.get_cmap('bwr_r', 21),
                                       rasterized=False)
        cb = phlp.colorbar(fig, ax2, im, width=0.05, height=0.4,
                           hoffset=-0.05, voffset=0.3)
        cb.set_label(r'($\mu$Amm$^{-3}$)', labelpad=0.1)

    # NOTE(review): tick positions assume the default T=[800, 1000]
    ax2.set_xticks([800, 900, 1000])
    ax2.axis(ax2.axis('tight'))

    if show_titles:
        if show_CSD:
            ax2.set_title('LFP & CSD', va='center')
        else:
            ax2.set_title('LFP', va='center')
    if show_xlabels:
        ax2.set_xlabel(r'$t$ (ms)', labelpad=0.)
    else:
        ax2.set_xlabel('')

    ############################################################################
    # d part, LFP power trace for each layer
    ############################################################################
    freqs, PSD = calc_signal_power(params,
                                   fname=os.path.join(params.savefolder,
                                                      'LFPsum.h5'),
                                   transient=transient,
                                   Df=Df,
                                   mlab=mlab,
                                   NFFT=NFFT,
                                   noverlap=noverlap,
                                   window=window)

    channels = [0, 3, 7, 11, 13]

    # create invisible axes to position labels correctly
    ax_ = fig.add_subplot(gs1[:, 3])
    phlp.annotate_subplot(ax_, ncols=5, nrows=1, letter=letters[3],
                          linear_offset=0.065)

    if show_titles:
        ax_.set_title('LFP PSD', va='center')

    ax_.axis('off')

    for i, ch in enumerate(channels):

        ax = fig.add_subplot(gs1[i, 3])
        phlp.remove_axis_junk(ax)

        if i == 0:
            ax.set_ylabel('(mV$^2$/Hz)', labelpad=0)

        ax.loglog(freqs[1:], PSD[ch][1:], color='k')
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')
        # tick labels only on the bottom-most channel axes
        if i < len(channels) - 1:
            ax.set_xticklabels([])
        ax.text(0.75, 0.85, 'ch. %i' % (ch + 1),
                horizontalalignment='left',
                verticalalignment='bottom',
                fontsize=6,
                transform=ax.transAxes)
        # NOTE(review): `bottom` only applies to x-axis ticks, so this call
        # has no effect for axis='y'; the NullLocator below is what removes
        # the minor y ticks. Boolean instead of the deprecated 'off' string.
        ax.tick_params(axis='y', which='minor', bottom=False)
        ax.axis(ax.axis('tight'))
        ax.yaxis.set_minor_locator(plt.NullLocator())

        ax.set_xlim(flim)
        ax.set_ylim(1E-7, 2E-4)
        if i != 0:
            ax.set_yticklabels([])

    if show_xlabels:
        ax.set_xlabel('$f$ (Hz)', labelpad=0.)

    ############################################################################
    # e part signal power
    ############################################################################
    ax4 = fig.add_subplot(gs1[:, 4])

    phlp.annotate_subplot(ax4, ncols=5, nrows=1, letter=letters[4],
                          linear_offset=0.065)

    fname = os.path.join(params.savefolder, 'LFPsum.h5')
    im = plot_signal_power_colorplot(ax4,
                                     params,
                                     fname=fname,
                                     transient=transient,
                                     Df=Df,
                                     mlab=mlab,
                                     NFFT=NFFT,
                                     window=window,
                                     cmap=plt.cm.get_cmap('gray_r', 12),
                                     vmin=1E-7,
                                     vmax=1E-4)
    phlp.remove_axis_junk(ax4)

    ax4.set_xlim(flim)

    cb = phlp.colorbar(fig, ax4, im, width=0.05, height=0.5,
                       hoffset=-0.05, voffset=0.5)
    cb.set_label('(mV$^2$/Hz)', labelpad=0.1)

    if show_titles:
        ax4.set_title('LFP PSD', va='center')
    if show_xlabels:
        ax4.set_xlabel(r'$f$ (Hz)', labelpad=0.)
    else:
        ax4.set_xlabel('')

    return fig
Ejemplo n.º 5
0
def fig_network_input_structure(fig, params, bottom=0.1, top=0.9, transient=200,
                                T=[800, 1000], Df=0., mlab=True, NFFT=256,
                                srate=1000, window=plt.mlab.window_hanning,
                                noverlap=256 * 3 // 4, letters='abcde',
                                flim=(4, 400), show_titles=True,
                                show_xlabels=True, show_CSD=False):
    '''
    Draw the top part of the figure comparing the PD-model and the
    modified-PD model as five side-by-side panels on ``fig``:

    a. spike raster of all populations within the time window ``T``
    b. firing-rate power spectra, one axes row per pair of populations
    c. LFP traces (optionally on top of a CSD color plot)
    d. LFP power spectra for a selection of channels
    e. LFP power spectra of all channels as a color plot

    Parameters
    ----------
    fig : matplotlib figure
        Figure the panels are added to.
    params : object
        Simulation parameters; read attributes are ``networkSimParams``,
        ``savefolder`` and ``Y``.
    bottom, top : float
        Vertical extent of the gridspec in figure coordinates.
    transient : float
        Start of the analysis window used for the spectra (same time units
        as ``T``); data before it is discarded.
    T : sequence of two numbers
        Time window (ms) shown in panels a and c.
    Df, mlab, NFFT, srate, window, noverlap :
        Spectral estimation settings. With ``mlab=True`` the firing-rate
        spectra are computed with ``plt.mlab.psd`` (``NFFT``, ``srate``,
        ``noverlap``, ``window``); otherwise ``hlp.powerspec`` is used with
        frequency resolution ``Df``. ``noverlap`` should be an integer.
    letters : str
        Panel labels, one character per panel (at least five).
    flim : tuple of two floats
        Frequency axis limits (Hz) for panels b, d and e.
    show_titles, show_xlabels, show_CSD : bool
        Toggle panel titles, x-axis labels and the CSD background in panel c.

    Returns
    -------
    fig : the same figure, with the panels added.
    '''
    # load spikes as database
    networkSim = CachedNetwork(**params.networkSimParams)
    if analysis_params.bw:
        networkSim.colors = phlp.get_colors(len(networkSim.X))

    # x-ticks of the time-domain panels follow T so they stay inside the
    # plotted window (for the default T this is [800, 900, 1000])
    t_ticks = [T[0], (T[0] + T[1]) // 2, T[1]]

    # use gridspec to get nicely aligned subplots throughout the panel
    gs1 = gridspec.GridSpec(5, 5, bottom=bottom, top=top)

    ############################################################################
    # A part, full dot display
    ############################################################################

    ax0 = fig.add_subplot(gs1[:, 0])
    phlp.remove_axis_junk(ax0)
    phlp.annotate_subplot(ax0, ncols=5, nrows=1, letter=letters[0],
                          linear_offset=0.065)

    x, y = networkSim.get_xy(T, fraction=1)
    networkSim.plot_raster(ax0, T, x, y,
                           markersize=0.2, marker='_',
                           alpha=1.,
                           legend=False, pop_names=True,
                           rasterized=False)
    ax0.set_ylabel('population', labelpad=0.)
    ax0.set_xticks(t_ticks)

    if show_titles:
        ax0.set_title('spiking activity', va='center')
    if show_xlabels:
        ax0.set_xlabel(r'$t$ (ms)', labelpad=0.)
    else:
        ax0.set_xlabel('')

    ############################################################################
    # B part, firing rate spectra
    ############################################################################

    # Spikes of the whole simulation after the transient; per population,
    # x holds spike times and y the ids of the spiking cells.
    T_all = [transient, networkSim.simtime]
    bins = np.arange(transient, networkSim.simtime + 1)

    x, y = networkSim.get_xy(T_all, fraction=1)

    # create invisible axes to position labels correctly
    ax_ = fig.add_subplot(gs1[:, 1])
    phlp.annotate_subplot(ax_, ncols=5, nrows=1, letter=letters[1],
                          linear_offset=0.065)
    if show_titles:
        ax_.set_title('firing rate PSD', va='center')
    ax_.axis('off')

    colors = phlp.get_colors(len(params.Y)) + ['k']

    row = 0
    label_set = False

    # one label per axes row; each row holds two consecutive populations
    layer_titles = ['L23E/I', 'L4E/I', 'L5E/I', 'L6E/I', 'TC']

    # whether the thalamic (TC) population fired at all after the transient
    TC = x['TC'].size > 0

    for i, X in enumerate(networkSim.X):

        # populations are drawn pairwise: every even index opens a new axes
        # row which the following (odd) population shares
        if i % 2 == 0:
            ax1 = fig.add_subplot(gs1[row, 1])
            phlp.remove_axis_junk(ax1)

            if x[X].size > 0:
                ax1.text(0.05, 0.85, layer_titles[row],
                         horizontalalignment='left',
                         verticalalignment='bottom',
                         transform=ax1.transAxes)

        # mean-subtracted population spike-count histogram
        hist = np.histogram(x[X], bins=bins)[0].astype(float)
        hist -= hist.mean()

        if mlab:
            Pxx, freqs = plt.mlab.psd(hist, NFFT=NFFT, Fs=srate,
                                      noverlap=noverlap, window=window)
        else:
            freqs, Pxx = hlp.powerspec([hist], tbin=1.,
                                       Df=Df, pointProcess=False)
            mask = np.where(freqs >= 0.)
            freqs = freqs[mask]
            Pxx = Pxx.flatten()[mask]
            Pxx = Pxx / (T_all[1] - T_all[0])**2

        if x[X].size > 0:
            # skip the 0 Hz bin, which cannot be shown on a log axis
            ax1.loglog(freqs[1:], Pxx[1:],
                       label=X, color=colors[i],
                       clip_on=True)
            ax1.axis(ax1.axis('tight'))
            ax1.set_ylim([5E-4, 5E2])
            ax1.set_yticks([1E-3, 1E-1, 1E1])
            if not label_set:
                ax1.set_ylabel(r'(s$^{-2}$/Hz)', labelpad=0.)
                label_set = True
            if i > 1:
                ax1.set_yticklabels([])
            # x-label only on the bottom row: the TC row when TC fired,
            # otherwise the L6 row
            if show_xlabels and ((X == 'TC') if TC else (i >= 6)):
                ax1.set_xlabel('$f$ (Hz)', labelpad=0.)
            if (i < 8) if TC else (i < 6):
                ax1.set_xticklabels([])
        else:
            ax1.axis('off')

        ax1.set_xlim(flim)

        if i % 2 == 0:
            row += 1

        ax1.yaxis.set_minor_locator(plt.NullLocator())

    ############################################################################
    # c part, LFP traces and CSD color plots
    ############################################################################

    ax2 = fig.add_subplot(gs1[:, 2])
    phlp.annotate_subplot(ax2, ncols=5, nrows=1, letter=letters[2],
                          linear_offset=0.065)

    phlp.remove_axis_junk(ax2)
    plot_signal_sum(ax2, params,
                    fname=os.path.join(params.savefolder, 'LFPsum.h5'),
                    unit='mV', T=T, ylim=[-1600, 40],
                    rasterized=False)

    # CSD background colorplot, over the same time window as the LFP traces
    if show_CSD:
        im = plot_signal_sum_colorplot(ax2, params,
                                       os.path.join(params.savefolder,
                                                    'CSDsum.h5'),
                                       unit=r'($\mu$Amm$^{-3}$)', T=T,
                                       colorbar=False,
                                       ylim=[-1600, 40], fancy=False,
                                       cmap=plt.cm.get_cmap('bwr_r', 21),
                                       rasterized=False)
        cb = phlp.colorbar(fig, ax2, im,
                           width=0.05, height=0.4,
                           hoffset=-0.05, voffset=0.3)
        cb.set_label(r'($\mu$Amm$^{-3}$)', labelpad=0.1)

    ax2.set_xticks(t_ticks)
    ax2.axis(ax2.axis('tight'))

    if show_titles:
        if show_CSD:
            ax2.set_title('LFP & CSD', va='center')
        else:
            ax2.set_title('LFP', va='center')
    if show_xlabels:
        ax2.set_xlabel(r'$t$ (ms)', labelpad=0.)
    else:
        ax2.set_xlabel('')

    ############################################################################
    # d part, LFP power trace for each layer
    ############################################################################

    freqs, PSD = calc_signal_power(params,
                                   fname=os.path.join(params.savefolder,
                                                      'LFPsum.h5'),
                                   transient=transient, Df=Df, mlab=mlab,
                                   NFFT=NFFT, noverlap=noverlap,
                                   window=window)

    channels = [0, 3, 7, 11, 13]

    # create invisible axes to position labels correctly
    ax_ = fig.add_subplot(gs1[:, 3])
    phlp.annotate_subplot(ax_, ncols=5, nrows=1, letter=letters[3],
                          linear_offset=0.065)

    if show_titles:
        ax_.set_title('LFP PSD', va='center')

    ax_.axis('off')

    for i, ch in enumerate(channels):

        ax = fig.add_subplot(gs1[i, 3])
        phlp.remove_axis_junk(ax)

        if i == 0:
            ax.set_ylabel('(mV$^2$/Hz)', labelpad=0)

        ax.loglog(freqs[1:], PSD[ch][1:], color='k')
        ax.xaxis.set_ticks_position('bottom')
        ax.yaxis.set_ticks_position('left')
        # tick labels only on the bottom-most channel
        if i < len(channels) - 1:
            ax.set_xticklabels([])
        ax.text(0.75, 0.85, 'ch. %i' % (ch + 1),
                horizontalalignment='left',
                verticalalignment='bottom',
                fontsize=6,
                transform=ax.transAxes)
        ax.tick_params(axis='y', which='minor', bottom='off')
        ax.axis(ax.axis('tight'))
        ax.yaxis.set_minor_locator(plt.NullLocator())

        ax.set_xlim(flim)
        ax.set_ylim(1E-7, 2E-4)
        if i != 0:
            ax.set_yticklabels([])

    if show_xlabels:
        ax.set_xlabel('$f$ (Hz)', labelpad=0.)

    ############################################################################
    # e part signal power
    ############################################################################

    ax4 = fig.add_subplot(gs1[:, 4])

    phlp.annotate_subplot(ax4, ncols=5, nrows=1, letter=letters[4],
                          linear_offset=0.065)

    fname = os.path.join(params.savefolder, 'LFPsum.h5')
    im = plot_signal_power_colorplot(ax4, params, fname=fname,
                                     transient=transient, Df=Df,
                                     mlab=mlab, NFFT=NFFT, window=window,
                                     cmap=plt.cm.get_cmap('gray_r', 12),
                                     vmin=1E-7, vmax=1E-4)
    phlp.remove_axis_junk(ax4)

    ax4.set_xlim(flim)

    cb = phlp.colorbar(fig, ax4, im,
                       width=0.05, height=0.5,
                       hoffset=-0.05, voffset=0.5)
    cb.set_label('(mV$^2$/Hz)', labelpad=0.1)

    if show_titles:
        ax4.set_title('LFP PSD', va='center')
    if show_xlabels:
        ax4.set_xlabel(r'$f$ (Hz)', labelpad=0.)
    else:
        ax4.set_xlabel('')

    return fig
# Ejemplo n.º 6 (Example no. 6)
# 0
def fig_intro(params,
              ana_params,
              T=[800, 1000],
              fraction=0.05,
              rasterized=False):
    '''
    Set up the introduction figure. It has four panels connected by
    arrows: (A) a sketch of the point-neuron network, (B) its spiking
    activity, (C) multicompartment neuron morphologies and (D) LFP traces
    in all channels.

    Parameters
    ----------
    params : object
        Simulation parameters; read attributes are ``networkSimParams`` and
        ``savefolder``.
    ana_params : object
        Analysis parameters; its ``set_PLOS_2column_fig_style`` method is
        called to set the figure style.
    T : sequence of two numbers
        Time window shown in panels B and D.
    fraction : float
        Fraction passed to ``networkSim.get_xy`` for the raster in panel B.
    rasterized : bool
        Passed to the raster (B) and population (C) plots.

    Returns
    -------
    fig : the newly created matplotlib figure.
    '''
    ana_params.set_PLOS_2column_fig_style(ratio=0.5)

    # load spikes as database
    networkSim = CachedNetwork(**params.networkSimParams)
    if analysis_params.bw:
        networkSim.colors = phlp.get_colors(len(networkSim.X))

    # set up figure and subplots
    fig = plt.figure()
    gs = gridspec.GridSpec(3, 4)

    fig.subplots_adjust(left=0.05, right=0.95, wspace=0.5, hspace=0.)

    # network diagram
    ax0_1 = fig.add_subplot(gs[:, 0], frameon=False)
    ax0_1.set_title('point-neuron network', va='bottom')

    network_sketch(ax0_1, yscaling=1.3)
    ax0_1.xaxis.set_ticks([])
    ax0_1.yaxis.set_ticks([])
    phlp.annotate_subplot(ax0_1,
                          ncols=4,
                          nrows=1,
                          letter='A',
                          linear_offset=0.065)

    # network raster
    ax1 = fig.add_subplot(gs[:, 1], frameon=True)
    phlp.remove_axis_junk(ax1)
    phlp.annotate_subplot(ax1,
                          ncols=4,
                          nrows=1,
                          letter='B',
                          linear_offset=0.065)

    x, y = networkSim.get_xy(T, fraction=fraction)
    networkSim.plot_raster(ax1,
                           T,
                           x,
                           y,
                           markersize=0.2,
                           marker='_',
                           alpha=1.,
                           legend=False,
                           pop_names=True,
                           rasterized=rasterized)
    ax1.set_ylabel('')
    ax1.xaxis.set_major_locator(plt.MaxNLocator(4))
    ax1.set_title('spiking activity', va='bottom')

    # Time of the first TC spike in the window, marked by a thin vertical
    # line in panels B and D. The TC population may have no spikes in T;
    # then there is nothing to mark (indexing [0] would raise IndexError).
    tc_time = x['TC'][0] if np.size(x['TC']) > 0 else None

    if tc_time is not None:
        a = ax1.axis()
        ax1.vlines(tc_time, a[2], a[3], 'k', lw=0.25)

    # population
    ax2 = fig.add_subplot(gs[:, 2], frameon=False)
    ax2.xaxis.set_ticks([])
    ax2.yaxis.set_ticks([])
    plot_population(ax2,
                    params,
                    isometricangle=np.pi / 24,
                    plot_somas=False,
                    plot_morphos=True,
                    num_unitsE=1,
                    num_unitsI=1,
                    clip_dendrites=True,
                    main_pops=True,
                    title='',
                    rasterized=rasterized)
    ax2.set_title('multicompartment\nneurons',
                  va='bottom',
                  fontweight='normal')
    phlp.annotate_subplot(ax2,
                          ncols=4,
                          nrows=1,
                          letter='C',
                          linear_offset=0.065)

    # LFP traces in all channels, vertically aligned with the morphologies
    ax3 = fig.add_subplot(gs[:, 3], frameon=True)
    phlp.remove_axis_junk(ax3)
    plot_signal_sum(ax3,
                    params,
                    fname=os.path.join(params.savefolder, 'LFPsum.h5'),
                    unit='mV',
                    vlimround=0.8,
                    T=T,
                    ylim=[ax2.axis()[2], ax2.axis()[3]],
                    rasterized=False)
    ax3.set_title('LFP', va='bottom')
    ax3.xaxis.set_major_locator(plt.MaxNLocator(4))
    phlp.annotate_subplot(ax3,
                          ncols=4,
                          nrows=1,
                          letter='D',
                          linear_offset=0.065)
    if tc_time is not None:
        a = ax3.axis()
        ax3.vlines(tc_time, a[2], a[3], 'k', lw=0.25)

    # draw arrows between neighbouring panels (head, tail x-positions in
    # figure-fraction coordinates, all at mid height)
    ax = plt.gca()
    for x_head, x_tail in ((0.27, .24), (0.52, .49), (0.78, .75)):
        ax.annotate(
            "",
            xy=(x_head, 0.5),
            xytext=(x_tail, 0.5),
            xycoords="figure fraction",
            arrowprops=dict(facecolor='black', arrowstyle='simple'),
        )

    return fig
# Ejemplo n.º 7 (Example no. 7)
# 0
if __name__ == '__main__':
    # Script entry point: build the parameter objects, point them at the
    # processed simulation output, load the spikes and create an empty 9x4
    # gridspec figure. The plotting that uses these names is not part of
    # this excerpt.

    params = multicompartment_params()
    ana_params = analysis_params.params()
    ana_params.set_PLOS_2column_fig_style(ratio=0.5)

    # figures go to <savefolder>/figures; processed spike output is read
    # from <savefolder>/processed_nest_output
    params.figures_path = os.path.join(params.savefolder, 'figures')
    params.spike_output_path = os.path.join(params.savefolder,
                                                       'processed_nest_output')
    # keep the network parameters in sync so CachedNetwork reads from the
    # same folder
    params.networkSimParams['spike_output_path'] = params.spike_output_path

    #load spike as database
    networkSim = CachedNetwork(**params.networkSimParams)
    if analysis_params.bw:
        # black-and-white mode: replace the population colors
        networkSim.colors = phlp.get_colors(len(networkSim.X))




    # NOTE(review): presumably in ms, like the `T` windows elsewhere in this
    # file (axes there are labelled "$t$ (ms)") -- confirm
    transient=200
    T=[890, 920]
    
    # NOTE(review): these flags are not read within this excerpt; presumably
    # used by the plotting code that follows in the full script
    show_ax_labels = True
    show_images = True
    # show_images = False if analysis_params.bw else True
    
    gs = gridspec.GridSpec(9,4)
    
    fig = plt.figure()
    fig.subplots_adjust(left=0.06, right=0.94, bottom=0.075, top=0.925, hspace=0.35, wspace=0.35)
# Ejemplo n.º 8 (Example no. 8)
# 0
def fig_intro(params, ana_params, T=[800, 1000], fraction=0.05, rasterized=False):
    '''
    Set up the introduction figure. It has four panels connected by
    arrows: (A) a sketch of the point-neuron network, (B) its spiking
    activity, (C) multicompartment neuron morphologies and (D) LFP traces
    in all channels.

    Parameters
    ----------
    params : object
        Simulation parameters; read attributes are ``networkSimParams`` and
        ``savefolder``.
    ana_params : object
        Analysis parameters; its ``set_PLOS_2column_fig_style`` method is
        called to set the figure style.
    T : sequence of two numbers
        Time window shown in panels B and D.
    fraction : float
        Fraction passed to ``networkSim.get_xy`` for the raster in panel B.
    rasterized : bool
        Passed to the raster (B) and population (C) plots.

    Returns
    -------
    fig : the newly created matplotlib figure.
    '''
    ana_params.set_PLOS_2column_fig_style(ratio=0.5)

    # load spikes as database
    networkSim = CachedNetwork(**params.networkSimParams)
    if analysis_params.bw:
        networkSim.colors = phlp.get_colors(len(networkSim.X))

    # set up figure and subplots
    fig = plt.figure()
    gs = gridspec.GridSpec(3, 4)
    fig.subplots_adjust(left=0.05, right=0.95, wspace=0.5, hspace=0.)

    # network diagram
    ax0_1 = fig.add_subplot(gs[:, 0], frameon=False)
    ax0_1.set_title('point-neuron network', va='bottom')

    network_sketch(ax0_1, yscaling=1.3)
    ax0_1.xaxis.set_ticks([])
    ax0_1.yaxis.set_ticks([])
    phlp.annotate_subplot(ax0_1, ncols=4, nrows=1, letter='A', linear_offset=0.065)

    # network raster
    ax1 = fig.add_subplot(gs[:, 1], frameon=True)
    phlp.remove_axis_junk(ax1)
    phlp.annotate_subplot(ax1, ncols=4, nrows=1, letter='B', linear_offset=0.065)

    x, y = networkSim.get_xy(T, fraction=fraction)
    networkSim.plot_raster(ax1, T, x, y, markersize=0.2, marker='_', alpha=1.,
                           legend=False, pop_names=True, rasterized=rasterized)
    ax1.set_ylabel('')
    ax1.xaxis.set_major_locator(plt.MaxNLocator(4))
    ax1.set_title('spiking activity', va='bottom')

    # Time of the first TC spike in the window, marked by a thin vertical
    # line in panels B and D. The TC population may have no spikes in T;
    # then there is nothing to mark (indexing [0] would raise IndexError).
    tc_time = x['TC'][0] if np.size(x['TC']) > 0 else None

    if tc_time is not None:
        a = ax1.axis()
        ax1.vlines(tc_time, a[2], a[3], 'k', lw=0.25)

    # population
    ax2 = fig.add_subplot(gs[:, 2], frameon=False)
    ax2.xaxis.set_ticks([])
    ax2.yaxis.set_ticks([])
    plot_population(ax2, params, isometricangle=np.pi/24, plot_somas=False,
                    plot_morphos=True, num_unitsE=1, num_unitsI=1,
                    clip_dendrites=True, main_pops=True, title='',
                    rasterized=rasterized)
    ax2.set_title('multicompartment\nneurons', va='bottom', fontweight='normal')
    phlp.annotate_subplot(ax2, ncols=4, nrows=1, letter='C', linear_offset=0.065)

    # LFP traces in all channels, vertically aligned with the morphologies
    ax3 = fig.add_subplot(gs[:, 3], frameon=True)
    phlp.remove_axis_junk(ax3)
    plot_signal_sum(ax3, params, fname=os.path.join(params.savefolder, 'LFPsum.h5'),
                    unit='mV', vlimround=0.8,
                    T=T, ylim=[ax2.axis()[2], ax2.axis()[3]],
                    rasterized=False)
    ax3.set_title('LFP', va='bottom')
    ax3.xaxis.set_major_locator(plt.MaxNLocator(4))
    phlp.annotate_subplot(ax3, ncols=4, nrows=1, letter='D', linear_offset=0.065)
    if tc_time is not None:
        a = ax3.axis()
        ax3.vlines(tc_time, a[2], a[3], 'k', lw=0.25)

    # draw arrows between neighbouring panels (head, tail x-positions in
    # figure-fraction coordinates, all at mid height)
    ax = plt.gca()
    for x_head, x_tail in ((0.27, .24), (0.52, .49), (0.78, .75)):
        ax.annotate("", xy=(x_head, 0.5), xytext=(x_tail, 0.5),
                    xycoords="figure fraction",
                    arrowprops=dict(facecolor='black', arrowstyle='simple'))

    return fig