Example #1
0
def plot_kc_spikes(ax, fd, ca_side='both', color='k', marker=','):
    """Raster plot KC spike times for KCs belonging to the specified side
    of calyx.

    This function does not care about spatial clusters.

    Parameters:
        ax: matplotlib axes to draw on.
        fd: open HDF5 data file.
        ca_side: 'both' to use every node under ``nda.kc_st_path``;
            otherwise a calyx side passed on to
            ``nda.get_kc_spike_nodes_by_region`` (e.g. 'LCA' or 'MCA').
        color, marker: passed to ``ax.plot``.

    If reading the KC spike times from `fd` raises KeyError, the spikes
    are read instead from a file named ``kc_spikes_<basename of fd>`` in
    the same directory as `fd`.

    Returns: the value returned by ``ax.plot`` (None if there are no
    spike trains), list of spike times and list of their y-positions.

    """
    if ca_side == 'both':
        nodes = fd[nda.kc_st_path].keys()
    else:
        nodes = nda.get_kc_spike_nodes_by_region(fd, ca_side)
    try:
        spike_x, spike_y = nda.get_event_times(
            fd['/data/event/kc/kc_spiketime'], nodes=nodes)
    except KeyError:
        # Fall back to a separately saved KC spike file next to the data file.
        dirname = os.path.dirname(fd.filename)
        kc_fname = 'kc_spikes_' + os.path.basename(fd.filename)
        # Open read-only: without an explicit mode older h5py defaults to
        # 'a', which would silently create an empty file if it is missing.
        with h5.File(os.path.join(dirname, kc_fname), 'r') as kc_file:
            spike_x, spike_y = nda.get_event_times(kc_file, nodes=nodes)
    if len(spike_x) > 0:
        ret = ax.plot(np.concatenate(spike_x),
                      np.concatenate(spike_y),
                      color=color,
                      marker=marker,
                      linestyle='none')
    else:
        ret = None
    return ret, spike_x, spike_y
def main():
    """Plot KC-removal series and per-KC spike-count histograms.

    For every index ``ii`` into SIX_SPIKES this builds:

    * a 4x3 figure whose columns are the ``ii``-th series of SIX_SPIKES,
      THREE_SPIKES and ALL_SPIKES, and whose rows are PN raster, KC
      raster, basal GGN Vm and a histogram of KC population spike times;
    * a histogram of per-KC spike counts (spiking KCs only).

    The final job id of each series is skipped. Every other job id gets a
    summary line printed, but only the one at ``kk == 1`` and the last one
    processed are plotted. Both figures are saved as SVG named after
    ``SIX_SPIKES[ii][0]`` and all figures are shown at the end.
    """
    plt.close('all')
    fig_ax = []
    fig_ax_hist = []
    for ii in range(len(SIX_SPIKES)):
        fig, ax = plt.subplots(nrows=4, ncols=3, sharex='row', sharey='row')
        fig_ax.append((fig, ax))
        fig_hist, ax_hist = plt.subplots()
        fig_ax_hist.append((fig_hist, ax_hist))
        for jj, sim_set in enumerate([SIX_SPIKES, THREE_SPIKES, ALL_SPIKES]):
            # The last job id of each series is not used.
            series = sim_set[ii][:-1]
            last_kk = len(series) - 1
            for kk, jid in enumerate(series):
                try:
                    fname = nda.find_h5_file(jid, DATA_DIR)
                except Exception:
                    # First entry is old data and moved to back up
                    # Template should still have all the data
                    print(f'JID {jid} not in datadir. Looking in template dir')
                    fname = nda.find_h5_file(jid, TEMPLATE_DIR)
                with h5.File(fname, 'r') as fd:
                    kc_st, kc_id = nda.get_event_times(fd[nda.kc_st_path])
                    kc_sc = np.array([len(st) for st in kc_st])
                    spiking_kc_count = np.count_nonzero(kc_sc)
                    spike_count = kc_sc.sum()
                    print('JID:', jid, ', spiking KCs:', spiking_kc_count, 'total spikes:', spike_count)
                    # Plot only the last successful KC removal and the
                    # entry at kk == 1.
                    if kk not in (1, last_kk):
                        continue
                    pn_st, pn_id = nda.get_event_times(fd[nda.pn_st_path])
                    ggn_vm, t = nda.get_ggn_vm(fd, 'basal')
                    kc_pop_st = np.concatenate(kc_st)
                    label = f'{ii}: {jj}: {jid}'
                    ax[0, jj].plot(np.concatenate(pn_st), np.concatenate(pn_id), ',')
                    ax[1, jj].plot(kc_pop_st, np.concatenate(kc_id), ',')
                    ax[2, jj].plot(t, ggn_vm[0, :], label=label)
                    ax[3, jj].hist(kc_pop_st, bins=np.arange(0, t[-1], 100e-3),
                                   alpha=0.5, label=label,
                                   histtype='step', linewidth=2)
                    # Skip the count histogram when no KC spiked; bins start
                    # at 1 to remove nonspiking KCs.
                    if kc_sc.any():
                        ax_hist.hist(kc_sc, bins=np.arange(1, kc_sc.max() + 0.5, 1),
                                     alpha=0.5, label=label,
                                     histtype='step', linewidth=2)
        # 210 x 290 mm expressed in inches
        fig.set_size_inches(210/25.4, 290/25.4)
        fig.savefig(f'fig_kc_removal_jid_{SIX_SPIKES[ii][0]}.svg', transparent=True)
        fig_hist.savefig(f'fig_kc_hist_kc_removal_jid_{SIX_SPIKES[ii][0]}.svg', transparent=True)

    plt.show()
Example #3
0
def plot_spike_counts(ax, fname=None):
    """Plot total KC spike counts across successive KC-removal steps.

    Parameters:
        ax: matplotlib axes to draw on.
        fname: optional path to a CSV file with columns ``series``,
            ``removal`` and ``total_spikes``. If None, the counts are read
            from the HDF5 files of the job ids in SIX_SPIKES, THREE_SPIKES
            and ALL_SPIKES (skipping the first and last job id of each
            series).

    Without a CSV file, one curve is drawn per index into SIX_SPIKES
    (the three series concatenated), with a 'k|' marker at the last
    point of each series. With a CSV file, one curve is drawn per
    ``series`` group, with a 'k.' marker at the last row of each
    ``removal`` group within it.
    """
    if fname is None:
        for ii in range(len(SIX_SPIKES)):
            spike_counts = []
            for sim_set in (SIX_SPIKES, THREE_SPIKES, ALL_SPIKES):
                for jid in sim_set[ii][1:-1]:
                    try:
                        h5_path = nda.find_h5_file(jid, DATA_DIR)
                    except Exception:
                        # First entry is old data and moved to back up
                        # Template should still have all the data
                        print(f'JID {jid} not in datadir. Looking in template dir')
                        h5_path = nda.find_h5_file(jid, TEMPLATE_DIR)
                    with h5.File(h5_path, 'r') as fd:
                        kc_st, _ = nda.get_event_times(fd[nda.kc_st_path])
                        kc_sc = sum(len(st) for st in kc_st)
                        print('JID:', jid, 'total spikes:', kc_sc)
                        spike_counts.append(kc_sc)
                # Mark the end of this series; nothing to mark while no
                # counts have been collected yet.
                if spike_counts:
                    ax.plot(len(spike_counts) - 1, spike_counts[-1], 'k|')
            ax.plot(spike_counts, 'o-', fillstyle='none')
    else:
        total_spike_count = pd.read_csv(fname, sep=',')
        for _, simgrp in total_spike_count.groupby('series'):
            ax.plot(simgrp['total_spikes'].values, 'o-', fillstyle='none')
            # reset_index numbers the rows 0..n-1 within this series, which
            # matches the x-coordinates of the curve just plotted.
            series_df = simgrp.reset_index()
            for removal, remgrp in series_df.groupby('removal'):
                print('#', removal)
                print(remgrp)
                ax.plot(remgrp.index.values[-1], remgrp['total_spikes'].values[-1], 'k.')
Example #4
0
def compare_data(leftfiles, rightfiles, leftheader, rightheader):
    """Compare two simulations side by side.

    For each pair ``(left, right)`` from ``zip(leftfiles, rightfiles)``
    (file names joined onto the module-level ``datadir``) a 6x2 figure is
    made, with the left file in column 0 and the right file in column 1:

    0. PN spike raster with population PSTH on a twin y-axis
    1. KC (LCA) spike raster by cluster
    2. KC population PSTH
    3. histogram of KC firing rates between stimulus onset and
       onset + duration + offdur
    4. KC Vm (LCA, 5 samples)
    5. GGN Vm in LCA (red) and in the basal dendrite (green)

    Returns:
        (figs, axeslist, psthaxlist): the figures, their 6x2 axes arrays
        and, per figure, the list of PN PSTH twin axes. If a file has no
        PN spike group, a message is printed and the results built so far
        are returned immediately; the figure for that pair is not included.
    """
    figs = []
    axeslist = []
    psthaxlist = []
    for left, right in zip(leftfiles, rightfiles):
        fig, axes = plt.subplots(nrows=6, ncols=2, sharey='row')
        psth_axes = []
        for col, fname in enumerate([left, right]):
            fpath = os.path.join(datadir, fname)
            with h5.File(fpath, 'r') as fd:
                config = nda.load_config(fd)
                bins = np.arange(0, nda.get_simtime(fd) + 0.5, 50.0)
                try:
                    pns = list(fd[nda.pn_st_path].keys())
                except KeyError:
                    print('Could not find PNs in', fname)
                    return figs, axeslist, psthaxlist
                # Order PNs by the numeric suffix of their node names.
                pns = sorted(pns, key=lambda x: int(x.split('_')[-1]))
                pn_st, pn_y = nda.get_event_times(fd[nda.pn_st_path], pns)
                axes[0, col].plot(np.concatenate(pn_st), np.concatenate(pn_y), ',')
                psth_ax = axes[0, col].twinx()
                psth_axes.append(psth_ax)
                plot_population_psth(psth_ax, pn_st, config['pn']['number'], bins)
                _, kc_st, _ = plot_kc_spikes_by_cluster(axes[1, col], fd, 'LCA')
                plot_population_psth(axes[2, col], kc_st, len(kc_st), bins,
                                     rate_sym='b^', cell_sym='rv')
                stiminfo = nda.get_stimtime(fd)
                onset = stiminfo['onset']
                stimend = onset + stiminfo['duration'] + stiminfo['offdur']
                # Spikes strictly inside (onset, stimend) scaled by
                # 1e3 / window length (presumably ms times -> Hz).
                rates = [len(st[(st > onset) & (st < stimend)]) * 1e3
                         / (stimend - onset) for st in kc_st]
                print(rates[:5])
                axes[3, col].hist(rates, bins=np.arange(21))
                axes[3, col].set_xlabel('Firing rate')
                plot_kc_vm(axes[4, col], fd, 'LCA', 5)
                plot_ggn_vm(axes[5, col], fd,
                            fd['/data/uniform/ggn_output/GGN_output_Vm'],
                            'LCA', 5, color='r')
                plot_ggn_vm(axes[5, col], fd,
                            fd['/data/uniform/ggn_basal/GGN_basal_Vm'],
                            'basal', 5, color='g')
                axes[5, col].set_ylim((-53, -35))
                # Reuse the config loaded above instead of reading it again.
                axes[0, col].set_title('{}\nFAKE? {}'.format(
                    fname, config['kc']['fake_clusters']))
        time_axes = [axes[row, c] for row in [0, 1, 2, 4, 5] for c in [0, 1]]
        # Hide x ticks on every time axis except the bottom row: the last
        # two entries are axes[5, 0] and axes[5, 1].
        for ax in time_axes[:-2]:
            ax.set_xticks([])
        # Link the x-limits of all time axes (row 2 included).
        # NOTE(review): joining via get_shared_x_axes() is deprecated since
        # matplotlib 3.6; newer versions need Axes.sharex instead.
        axes[0, 0].get_shared_x_axes().join(*time_axes)
        psth_axes[0].autoscale()
        fig.text(0.1, 0.95, leftheader, ha='left', va='bottom')
        fig.text(0.6, 0.95, rightheader, ha='left', va='bottom')
        fig.set_size_inches(15, 10)
        figs.append(fig)
        axeslist.append(axes)
        psthaxlist.append(psth_axes)
    return figs, axeslist, psthaxlist
def make_psth_and_vm(ax_psth, ax_vm, ax_kc_hist):
    """Overlay KC PSTHs, basal GGN Vm and KC spike-count histograms.

    For the first series of SIX_SPIKES, THREE_SPIKES and ALL_SPIKES
    (with the first and last job ids dropped), the first and the last
    remaining job ids are plotted. Each series gets its own color; a solid
    line marks the first job and a dotted line the last.

    Parameters:
        ax_psth: axes for the KC population PSTH (bins of width
            ``binwidth`` with edges from 500 to 2000).
        ax_vm: axes for the basal GGN Vm trace.
        ax_kc_hist: axes for the histogram of per-KC spike counts
            (bins start at 1, so non-spiking KCs are left out).
    """
    binwidth = 100
    datalist = (SIX_SPIKES[0][1:-1], THREE_SPIKES[0][1:-1],
                ALL_SPIKES[0][1:-1])
    colors = ['#e66101', '#5e3c99', '#009292']
    ls = ['-', ':']
    for ii, group in enumerate(datalist):
        print(group)
        if len(group) == 0:
            # Nothing left after dropping the first and last job ids.
            continue
        for jj, jid in enumerate((group[0], group[-1])):
            print(jid)
            try:
                fname = nda.find_h5_file(jid, DATA_DIR)
            except Exception:
                # First entry is old data and moved to back up
                # Template should still have all the data
                print(f'JID {jid} not in datadir. Looking in template dir')
                fname = nda.find_h5_file(jid, TEMPLATE_DIR)
            with h5.File(fname, 'r') as fd:
                kc_st, _ = nda.get_event_times(fd[nda.kc_st_path])
                kc_sc = np.array([len(st) for st in kc_st])
                try:
                    ax_kc_hist.hist(kc_sc,
                                    bins=np.arange(1,
                                                   max(kc_sc) + 0.5, 1),
                                    color=colors[ii],
                                    ls=ls[jj],
                                    label=f'{ii}: {jj}: {jid}',
                                    histtype='step',
                                    linewidth=1)
                except (IndexError, ValueError):
                    # max() of an empty array raises ValueError; presumably
                    # an empty bin array (no KC spiked) gives IndexError.
                    print(jid, ':', kc_sc, '|')

                pop_st = np.concatenate(kc_st)
                try:
                    ax_psth.hist(pop_st,
                                 bins=np.arange(500, 2100, binwidth),
                                 color=colors[ii],
                                 ls=ls[jj],
                                 histtype='step',
                                 label=jid)
                except IndexError:
                    print(jid, pop_st)
                ggn_vm, t = nda.get_ggn_vm(fd, 'basal')
                ax_vm.plot(t,
                           ggn_vm[0, :],
                           label=jid,
                           color=colors[ii],
                           ls=ls[jj])
    ax_psth.legend()
    ax_vm.legend()
Example #6
0
def _attr_text(value):
    """Return `value` as str, decoding it first if it is bytes.

    String attributes and elements read from the HDF5 files may come back
    either as bytes or as str, so decode only when needed.
    """
    if isinstance(value, bytes):
        return value.decode()
    return value


def _load_raster_config(fd):
    """Return the parsed simulation config for open file `fd`, or None.

    Looks, in order, at the 'config' attribute of `fd`, the stored
    'model/filecontents/mb/network/config.yaml' dataset, and finally the
    'config' attribute of the file named by ``fd.attrs['original']``
    (possibly a fixed-network run whose template holds the config).
    """
    # safe_load: plain config data; yaml.load without a Loader fails on
    # PyYAML >= 6.
    try:
        return yaml.safe_load(_attr_text(fd.attrs['config']))
    except KeyError:
        pass
    try:
        return yaml.safe_load(
            _attr_text(fd['model/filecontents/mb/network/config.yaml'][0]))
    except KeyError:
        pass
    try:
        original = _attr_text(fd.attrs['original'])
        with h5.File(original, 'r') as origfd:
            return yaml.safe_load(_attr_text(origfd.attrs['config']))
    except KeyError:
        print('No config attribute or model file')
    return None


def plot_spike_rasters(fname, vm_samples=10, psth_bin_width=50.0,
                       kde_bw=50.0, by_cluster=False):
    """Plot an overview of one pn_kc_ggn simulation.

    The file `fname` has data from pn_kc_ggn simulation. In the early
    ones I did not record the spike times for KCs. binwidths are in ms.

    Rows (shared x-axis): PN raster with PSTH/KDE, KC LCA raster with
    PSTH/KDE, KC MCA raster with PSTH/KDE, KC LCA Vm, KC MCA Vm, and GGN
    Vm (LCA red, MCA blue, alphaL black, basal green).

    Parameters:
        fname: path of the HDF5 data file.
        vm_samples: passed through to plot_kc_vm / plot_ggn_vm
            (presumably the number of Vm traces to draw).
        psth_bin_width: PSTH bin width in ms.
        kde_bw: KDE bandwidth, passed to plot_population_KDE.
        by_cluster: draw KC rasters with plot_kc_spikes_by_cluster
            instead of plot_kc_spikes.

    Returns:
        (fig, axes) with six vertically stacked axes.

    Raises:
        KeyError: if no simulation config can be found for the file.
    """
    start = timer()
    print('Processing', fname)
    with h5.File(fname, 'r') as fd:
        try:
            print('Description:', fd.attrs['description'])
        except KeyError:
            print('No description available')
        config = _load_raster_config(fd)
        if config is None:
            raise KeyError('No simulation config found for {}'.format(fname))
        # PN spike raster
        pn_st = fd['/data/event/pn/pn_spiketime']
        fig, axes = plt.subplots(nrows=6, ncols=1, sharex=True)
        try:
            if 'calyx' in _attr_text(fd.attrs['description']):
                descr = 'KC->GGN in alphaL + CA'
            else:
                descr = 'KC->GGN in alphaL only'
        except KeyError:
            descr = ''

        fig.suptitle('{} {}'.format(os.path.basename(fname), descr))
        ax_pn_spike_raster = axes[0]
        print('Plotting PN spikes')
        ax_pn_spike_raster.set_title('PN spike raster')
        # Order PN nodes by the numeric suffix of their names.
        nodes = [int(node.split('_')[-1]) for node in pn_st]
        nodes = ['pn_{}'.format(node) for node in sorted(nodes)]
        spike_x, spike_y = nda.get_event_times(pn_st, nodes)
        ax_pn_spike_raster.plot(np.concatenate(spike_x), np.concatenate(spike_y), 'k,')
        stim = config['stimulus']
        simtime = (Q_(stim['onset']).to('ms').m
                   + Q_(stim['duration']).to('ms').m
                   + Q_(stim['tail']).to('ms').m)
        psth_bins = np.arange(0, simtime, psth_bin_width)
        # linspace needs an integer sample count; a float raises TypeError
        # in current NumPy.
        kde_grid = np.linspace(0, simtime, 100)
        ax_pn_psth = ax_pn_spike_raster.twinx()
        _, sr, _ = plot_population_psth(ax_pn_psth, spike_x,
                                        config['pn']['number'], psth_bins)
        plot_population_KDE(ax_pn_psth, spike_x, kde_grid, kde_bw, color='y',
                            maxamp=max(sr))
        ax_kc_lca_spike_raster = axes[1]
        print('Plotting KC PSTH in LCA')
        ax_kc_lca_spike_raster.set_title('KC LCA')
        if by_cluster:
            _, lca_spike_x, lca_spike_y = plot_kc_spikes_by_cluster(ax_kc_lca_spike_raster, fd, 'LCA')
        else:
            _, lca_spike_x, lca_spike_y = plot_kc_spikes(ax_kc_lca_spike_raster, fd, 'LCA')
        ax_kc_lca_psth = ax_kc_lca_spike_raster.twinx()
        _, sr, _ = plot_population_psth(ax_kc_lca_psth, lca_spike_x, len(lca_spike_x), psth_bins)
        plot_population_KDE(ax_kc_lca_psth, lca_spike_x, kde_grid, kde_bw, color='y', maxamp=max(sr))
        ax_kc_lca_psth.legend()
        ax_kc_mca_spike_raster = axes[2]
        print('Plotting KC PSTH in MCA')
        ax_kc_mca_spike_raster.set_title('KC MCA')
        if by_cluster:
            _, mca_spike_x, mca_spike_y = plot_kc_spikes_by_cluster(ax_kc_mca_spike_raster, fd, 'MCA')
        else:
            _, mca_spike_x, mca_spike_y = plot_kc_spikes(ax_kc_mca_spike_raster, fd, 'MCA')
        ax_kc_mca_psth = ax_kc_mca_spike_raster.twinx()
        _, sr, _ = plot_population_psth(ax_kc_mca_psth, mca_spike_x, len(mca_spike_x), psth_bins)
        plot_population_KDE(ax_kc_mca_psth, mca_spike_x, kde_grid, kde_bw, color='y', maxamp=max(sr))
        ax_kc_mca_psth.legend()
        # LCA KC Vm
        ax_kc_lca_vm = axes[3]
        ax_kc_lca_vm.set_title('KC LCA')
        plot_kc_vm(ax_kc_lca_vm, fd, 'LCA', vm_samples)
        # MCA KC Vm
        ax_kc_mca_vm = axes[4]
        ax_kc_mca_vm.set_title('KC MCA')
        plot_kc_vm(ax_kc_mca_vm, fd, 'MCA', vm_samples)
        # GGN MCA Vm, GGN LCA Vm
        ggn_vm_plot = axes[5]
        ggn_vm_plot.set_title('GGN Vm')
        ggn_output_vm = fd['/data/uniform/ggn_output/GGN_output_Vm']
        plot_ggn_vm(ggn_vm_plot, fd, ggn_output_vm, 'LCA', vm_samples, color='r')
        plot_ggn_vm(ggn_vm_plot, fd, ggn_output_vm, 'MCA', vm_samples, color='b')
        # Empty proxy lines give the legend one entry per recording site.
        lca, = ggn_vm_plot.plot([], color='r', label='LCA')
        mca, = ggn_vm_plot.plot([], color='b', label='MCA')
        alpha, = ggn_vm_plot.plot([], color='k', label='alphaL')
        basal, = ggn_vm_plot.plot([], color='g', label='basal')
        ggn_vm_plot.legend(handles=[lca, mca, alpha, basal])
        # GGN alphaL Vm
        ggn_alphaL_vm = fd['/data/uniform/ggn_alphaL_input/GGN_alphaL_input_Vm']
        plot_ggn_vm(ggn_vm_plot, fd, ggn_alphaL_vm, 'alphaL', vm_samples, color='k')
        try:
            ggn_basal_vm = fd['/data/uniform/ggn_basal/GGN_basal_Vm']
            plot_ggn_vm(ggn_vm_plot, fd, ggn_basal_vm, 'basal', vm_samples, color='g')
        except KeyError:
            warnings.warn('No basal Vm recorded from GGN')
        end = timer()
        print('Time for plotting {}s'.format(end - start))
        return fig, axes
Example #7
0
# Two stacked panels sharing the x-axis: KC spike raster on top (3/4 of the
# height) and GGN basal Vm below.
# NOTE(review): `fname`, `jid` and `myplot` must be defined before this
# point; they are not visible in this chunk.
gs = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[3, 1], hspace=0.05)
fig = plt.figure()
ax0 = fig.add_subplot(gs[0])
ax1 = fig.add_subplot(gs[1], sharex=ax0)
axes = [ax0, ax1]
with h5.File(fname, 'r') as fd:
    # Report the number of spiking KCs and dump the run's config.
    print('jid: {} spiking KCs: {}'.format(jid, len(nda.get_spiking_kcs(fd))))
    print(yaml.dump(nda.load_config(fd), default_style=''))
    # Disabled: raster of every 10th PN.
    # pn_st = []
    # pn_id = []
    # for pn in fd[nda.pn_st_path].values():
    #     pn_st.append(pn.value)
    #     pn_id.append([int(pn.name.rpartition('_')[-1])] * len(pn))
    # ax0.plot(np.concatenate(pn_st[::10]), np.concatenate(pn_id[::10]), 'k,')
    kc_x, kc_y = nda.get_event_times(fd[nda.kc_st_path])
    # Only every 10th KC spike train is drawn.
    ax0.plot(np.concatenate(kc_x[::10]), np.concatenate(kc_y[::10]), 'k,')
    # Basal GGN Vm at location 'dend_b'; the 5th argument is presumably
    # the number of Vm samples to draw - confirm in myplot.plot_ggn_vm.
    myplot.plot_ggn_vm(ax1,
                       fd,
                       fd['/data/uniform/ggn_basal/GGN_basal_Vm'],
                       'dend_b',
                       1,
                       color='k')
    ax1.set_ylim(-51, -45)
    ax1.set_xlim(700, 2500)
    # Ticks at 700, 1500 and 2500 are labelled (t - 500) / 1000, i.e.
    # 0.2, 1.0 and 2.0, so the axis lines up with the newer simulations.
    xticks = np.array([
        200.0, 1000.0, 2000.0
    ]) + 500  # Here the onset was at 1s, 500 ms past the newer simulations
    ax1.set_xticks(xticks)
    ax1.set_xticklabels((xticks - 500) / 1000.0)
    # NOTE(review): -40.0 lies outside set_ylim(-51, -45) above; set_yticks
    # may expand the view limits to include it - confirm intended y-range.
    ax1.set_yticks([-50.0, -40.0])