# Example 1
def getSortedTimes(dirName, chanGroup):
    """Return per-segment arrays of sorted spike times, in seconds.

    Parameters
    ----------
    dirName : str
        tridesclous DataIO working directory.
    chanGroup : int
        Channel group whose catalogue is loaded.

    Returns
    -------
    numpy.ndarray (dtype=object)
        One entry per segment. Each entry is a 2-D array whose first row
        holds the cluster labels and whose following rows hold each
        cluster's spike times (seconds), zero-padded to equal length.
    """
    dataio = DataIO(dirname=dirName)
    dataio.load_catalogue(chan_grp=chanGroup)
    catalogueconstructor = CatalogueConstructor(dataio=dataio)

    sample_rate = dataio.sample_rate

    unitTimes = np.empty(dataio.nb_segment, dtype=object)

    for j in range(dataio.nb_segment):

        idd = {}
        times = {}

        try:
            # List of all cluster labels
            cluster_ids = np.array([i for i in catalogueconstructor.cluster_labels])
            # Cluster label of every detected peak in this segment
            clusters = np.array([i[1] for i in dataio.get_spikes(j)])
            # Sample index of every detected peak in this segment
            spike_times = np.array([i[0] for i in dataio.get_spikes(j)])

        except Exception:
            # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit
            # still propagate. Treat any read failure as "no spikes".
            cluster_ids = np.array([])
            clusters = np.array([])
            spike_times = np.array([])

        for i in cluster_ids:
            idd[i] = np.argwhere(clusters == i)

        for i in cluster_ids:
            # Convert sample indices to seconds.
            times[i] = spike_times[idd[i]] / sample_rate

        if not times:
            # No clusters for this segment: the original np.max over an
            # empty sequence would raise ValueError; store an empty result.
            unitTimes[j] = np.empty((0, 0))
            continue

        mx = np.max([times[i].size for i in times.keys()])

        # Zero-pad every cluster's time list to mx + 1 entries. The extra
        # slot guarantees at least one trailing zero per row, so the flat
        # np.roll below never pushes real data across row boundaries.
        for i in times.keys():
            times[i].resize(mx + 1, 1)

        timesArray = np.array([times[i] for i in times.keys()])

        # Shift everything one slot (flattened roll), then overwrite the
        # freed first slot of each row with that row's cluster label.
        timesArray = np.roll(timesArray, 1)
        timesArray[:, 0, :] = np.array(list(times.keys())).reshape(timesArray.shape[0], 1)

        timesArray = np.transpose(timesArray)

        unitTimes[j] = timesArray[0]

    return unitTimes
# Example 2
def apply_peeler():
    """Run the Peeler over the recording using the chan-group-0 catalogue."""
    io = DataIO(dirname=dirname)
    loaded_catalogue = io.load_catalogue(chan_grp=0)

    spike_peeler = Peeler(io)
    spike_peeler.change_params(catalogue=loaded_catalogue, chunksize=1024)
    spike_peeler.run(progressbar=True)
# Example 3
def make_catalogue_figure():
    """Plot every catalogue template (chan group 0) on the probe geometry.

    Relies on module-level `dirname` and `channels`; the figure is saved
    to ../img/peeler_templates_for_animation.png.
    """
    io = DataIO(dirname=dirname)
    cat = io.load_catalogue(chan_grp=0)

    cluster_table = cat['clusters']
    probe_geometry = io.get_geometry(chan_grp=0)

    fig, ax = plt.subplots()
    ax.set_title('Catalogue have 4 templates')

    for idx in range(cluster_table.size):
        # Per-cluster display colour, stored packed as int32.
        rgba = int32_to_rgba(cluster_table[idx]['color'], mode='float')

        # Keep the leading axis so the plot helper gets a 3-D array.
        template = cat['centers0'][idx:idx + 1]

        plot_waveforms_with_geometry(template,
                                     channels,
                                     probe_geometry,
                                     ax=ax,
                                     ratioY=3,
                                     deltaX=50,
                                     margin=50,
                                     color=rgba,
                                     linewidth=3,
                                     alpha=1,
                                     show_amplitude=True,
                                     ratio_mad=8)

    fig.savefig('../img/peeler_templates_for_animation.png')
# Example 4
def open_PeelerWindow(dirname, chan_grp):
    """Open the interactive PeelerWindow GUI for one channel group."""
    io = DataIO(dirname=dirname)
    catalogue = io.load_catalogue(chan_grp=chan_grp)

    qt_app = pg.mkQApp()
    window = PeelerWindow(dataio=io, catalogue=catalogue)
    window.show()
    qt_app.exec_()
# Example 5
def run_peeler(dirname, chan_grp):
    """Run the Peeler for one channel group and print its wall-clock time."""
    io = DataIO(dirname=dirname, ch_grp=chan_grp)
    catalogue = io.load_catalogue(chan_grp=chan_grp)

    peeler = Peeler(io)
    # Dense templates, no OpenCL: plain CPU peeling.
    peeler.change_params(catalogue=catalogue,
                         chunksize=32768,
                         use_sparse_template=False,
                         sparse_threshold_mad=1.5,
                         use_opencl_with_sparse=False)

    start = time.perf_counter()
    peeler.run()
    elapsed = time.perf_counter() - start
    print('peeler.run', elapsed)
# Example 6
def run_peeler(dirname):
    """Run the Peeler with default params, then print per-segment spike counts."""
    io = DataIO(dirname=dirname)
    catalogue = io.load_catalogue(chan_grp=0)

    peeler = Peeler(io)
    peeler.change_params(catalogue=catalogue)

    start = time.perf_counter()
    peeler.run()
    stop = time.perf_counter()
    print('peeler.run', stop - start)

    print()
    for seg_num in range(io.nb_segment):
        seg_spikes = io.get_spikes(seg_num)
        print('seg_num', seg_num, 'nb_spikes', seg_spikes.size)
# Example 7
def test_dataio_catalogue():
    """Round-trip a catalogue through DataIO.save_catalogue/load_catalogue."""
    # Start from a clean working directory so stale state can't mask failures.
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')

    dataio = DataIO(dirname='test_DataIO')

    catalogue = {}
    catalogue['chan_grp'] = 0
    catalogue['centers0'] = np.ones((300, 12, 50))

    catalogue['n_left'] = -15
    catalogue['params_signalpreprocessor'] = {'highpass_freq' : 300.}

    dataio.save_catalogue(catalogue, name='test')

    c2 = dataio.load_catalogue(name='test', chan_grp=0)
    print(c2)
    assert c2['n_left'] == -15
    assert np.all(c2['centers0']==1)
    # BUG FIX: the original asserted on the *input* dict `catalogue`, which is
    # trivially true; the loaded `c2` is what must survive the round trip.
    assert c2['params_signalpreprocessor']['highpass_freq'] == 300.
# Example 8
def export_spikes(dirname, array_idx, chan_grp):
    """Export the sorted spikes of one channel group to a CSV file.

    Parameters
    ----------
    dirname : str
        tridesclous working directory; the CSV is written here as
        '<array>_<chan_grp>_spikes.csv'.
    array_idx : int
        Index into cfg['arrays'] naming the electrode array.
    chan_grp : int
        Channel group to export.
    """
    print('Exporting ch %d' % chan_grp)
    data = {
        'array': [],
        'electrode': [],
        'cell': [],
        'segment': [],
        'time': []
    }
    array = cfg['arrays'][array_idx]

    dataio = DataIO(dirname=dirname, ch_grp=chan_grp)
    catalogue = dataio.load_catalogue(chan_grp=chan_grp)
    dataio._open_processed_data(ch_grp=chan_grp)

    clusters = catalogue['clusters']

    for seg_num in range(dataio.nb_segment):
        spikes = dataio.get_spikes(seg_num=seg_num, chan_grp=chan_grp)

        # Map cluster labels to cell labels. BUG FIX: build every mask from
        # the pristine label array, not the array being rewritten — the old
        # code could remap a spike twice whenever one cluster's cell_label
        # collided with a later cluster's cluster_label.
        original_labels = spikes['cluster_label']
        spike_labels = original_labels.copy()
        for l in clusters:
            mask = original_labels == l['cluster_label']
            spike_labels[mask] = l['cell_label']
        spike_indexes = spikes['index']

        for (index, label) in zip(spike_indexes, spike_labels):
            # Negative labels mark noise / unsorted events; skip them.
            if label >= 0:
                data['array'].append(array)
                data['electrode'].append(chan_grp)
                data['cell'].append(label)
                data['segment'].append(seg_num)
                data['time'].append(index)
        dataio.flush_processed_signals(seg_num=seg_num, chan_grp=chan_grp)
    df = pd.DataFrame(
        data, columns=['array', 'electrode', 'cell', 'segment', 'time'])
    df.to_csv(os.path.join(dirname, '%s_%d_spikes.csv' % (array, chan_grp)),
              index=False)
# Example 9
def make_animation():
    """Render the Peeler's peel-and-subtract loop as a sequence of PNG frames.

    A short window of processed signal (1.272-1.295 s, chosen because it
    contains a spike collision) is plotted; then for each detected spike
    its template prediction is drawn and subtracted from the running
    residual, saving a frame to png/ at every step. Relies on the
    module-level `dirname` and `channels`.
    """
    dataio = DataIO(dirname=dirname)
    catalogue = dataio.load_catalogue(chan_grp=0)

    clusters = catalogue['clusters']

    sr = dataio.sample_rate

    # also a good collision window: 11.356 - 11.366
    t1, t2 = 1.272, 1.295
    i1, i2 = int(t1 * sr), int(t2 * sr)

    # Keep only the spikes that fall inside the displayed time window.
    spikes = dataio.get_spikes()
    spike_times = spikes['index'] / sr
    keep = (spike_times >= t1) & (spike_times <= t2)

    spikes = spikes[keep]
    print(spikes)

    sigs = dataio.get_signals_chunk(i_start=i1,
                                    i_stop=i2,
                                    signal_type='processed')
    sigs = sigs.copy()
    times = np.arange(sigs.shape[0]) / dataio.sample_rate

    def plot_spread_sigs(sigs, ax, ratioY=0.02, **kargs):
        # Spread the channels vertically so traces do not overlap.
        sigs2 = sigs * ratioY
        sigs2 += np.arange(0, len(channels))[np.newaxis, :]
        ax.plot(times, sigs2, **kargs)

        ax.set_ylim(-0.5, len(channels) - .5)
        ax.set_xticks([])
        ax.set_yticks([])

    residuals = sigs.copy()

    # Spike indices relative to the start of the extracted chunk.
    local_spikes = spikes.copy()
    local_spikes['index'] -= i1

    num_fig = 0

    # This axis accumulates every template prediction across the loop.
    fig_pred, ax_predictions = plt.subplots()
    ax_predictions.set_title('All detected templates from catalogue')

    fig, ax = plt.subplots()
    plot_spread_sigs(residuals, ax, color='k', lw=2)
    ax.set_title('Initial filtered signals with spikes')

    fig.savefig('../img/peeler_animation_sigs.png')

    fig.savefig('png/fig{}.png'.format(num_fig))
    num_fig += 1

    for i in range(local_spikes.size):
        label = local_spikes['cluster_label'][i]

        color = clusters[clusters['cluster_label'] == label]['color'][0]
        color = int32_to_rgba(color, mode='float')

        # Template prediction for this single spike over the whole chunk.
        pred = make_prediction_signals(local_spikes[i:i + 1], 'float32',
                                       (i2 - i1, len(channels)), catalogue)

        fig, ax = plt.subplots()
        plot_spread_sigs(residuals, ax, color='k', lw=2)
        plot_spread_sigs(pred, ax, color=color, lw=1.5)
        # typo fix: title previously read 'Dected spike label {}'
        ax.set_title('Detected spike label {}'.format(label))

        fig.savefig('png/fig{}.png'.format(num_fig))
        num_fig += 1

        # Peel: remove this spike's contribution from the residual.
        residuals -= pred

        plot_spread_sigs(pred, ax_predictions, color=color, lw=1.5)

        fig, ax = plt.subplots()
        plot_spread_sigs(residuals, ax, color='k', lw=2)
        plot_spread_sigs(pred, ax, color=color, lw=1, ls='--')
        # typo fix: title previously read 'New residual after substraction'
        ax.set_title('New residual after subtraction')

        fig.savefig('png/fig{}.png'.format(num_fig))
        num_fig += 1

    fig_pred.savefig('png/fig{}.png'.format(num_fig))
    num_fig += 1
# Example 10
def generate_spike_sorting_report(subject, recording_date):
    """Render an HTML spike-sorting report for one subject/recording date.

    Walks every array and channel group under the subject's sorting
    directory, regenerates per-cluster and per-channel summary figures,
    collects existing figure paths, and fills the
    'spike_sorting_results_template.html' Jinja template. Writes
    spike_sorting_report.html (plus style.css) into the data directory.
    Relies on the module-level `cfg` dict for directory layout and array
    names. Does nothing when the data directory is missing.
    """

    data_dir = os.path.join(cfg['single_unit_spike_sorting_dir'], subject,
                            recording_date)
    if os.path.exists(data_dir):
        # One entry per (array, channel) pair, consumed by the template.
        channel_results = []

        for array_idx in range(len(cfg['arrays'])):
            array = cfg['arrays'][array_idx]
            print(array)

            array_data_dir = os.path.join(data_dir, 'array_%d' % array_idx)

            if os.path.exists(array_data_dir):
                export_path = os.path.join(array_data_dir, 'figures')
                if not os.path.exists(export_path):
                    os.makedirs(export_path)

                for chan_grp in range(cfg['n_channels_per_array']):
                    print(chan_grp)

                    dataio = DataIO(array_data_dir, ch_grp=chan_grp)
                    # NOTE(review): hard-coded gain (presumably volts per bit
                    # for this acquisition system) — confirm against hardware.
                    dataio.datasource.bit_to_microVolt = 0.195
                    catalogueconstructor = CatalogueConstructor(
                        dataio=dataio, chan_grp=chan_grp)
                    catalogueconstructor.refresh_colors()
                    catalogue = dataio.load_catalogue(chan_grp=chan_grp)

                    # Figure paths are stored relative to data_dir so the
                    # HTML report can reference them portably.
                    channel_result = {
                        'array': array,
                        'channel': chan_grp,
                        'init_waveforms': '',
                        'clean_waveforms': '',
                        'noise': '',
                        'init_clusters': '',
                        'merge_clusters': [],
                        'final_clusters': [],
                        'all_clusters': ''
                    }

                    clusters = catalogue['clusters']

                    cluster_labels = clusters['cluster_label']
                    cell_labels = clusters['cell_label']

                    # Pre-existing figures produced by earlier sorting steps;
                    # only their relative paths are recorded here.
                    channel_result['init_waveforms'] = os.path.join(
                        'array_%d' % array_idx, 'figures',
                        'chan_%d_init_waveforms.png' % chan_grp)
                    channel_result['clean_waveforms'] = os.path.join(
                        'array_%d' % array_idx, 'figures',
                        'chan_%d_clean_waveforms.png' % chan_grp)
                    channel_result['noise'] = os.path.join(
                        'array_%d' % array_idx, 'figures',
                        'chan_%d_noise.png' % chan_grp)
                    channel_result['init_clusters'] = os.path.join(
                        'array_%d' % array_idx, 'figures',
                        'chan_%d_init_clusters.png' % chan_grp)

                    # Collect any merge-step figures already on disk.
                    merge_files = glob.glob(
                        os.path.join(export_path,
                                     'chan_%d_merge_*.png' % chan_grp))
                    for merge_file in merge_files:
                        [path, file] = os.path.split(merge_file)
                        channel_result['merge_clusters'].append(
                            os.path.join('array_%d' % array_idx, 'figures',
                                         file))

                    # Regenerate one summary figure per final cluster.
                    for cluster_label in cluster_labels:
                        fig = plot_cluster_summary(dataio, catalogue, chan_grp,
                                                   cluster_label)
                        fname = 'chan_%d_cluster_%d.png' % (chan_grp,
                                                            cluster_label)
                        fig.savefig(os.path.join(export_path, fname))
                        # Close figures immediately to bound memory across
                        # the many channels processed in this loop.
                        fig.clf()
                        plt.close()
                        channel_result['final_clusters'].append(
                            os.path.join('array_%d' % array_idx, 'figures',
                                         fname))

                    # One combined figure for all clusters on this channel.
                    fig = plot_clusters_summary(dataio, catalogueconstructor,
                                                chan_grp)
                    fname = 'chan_%d_clusters.png' % chan_grp
                    fig.savefig(os.path.join(export_path, fname))
                    fig.clf()
                    plt.close()
                    channel_result['all_clusters'] = os.path.join(
                        'array_%d' % array_idx, 'figures', fname)

                    channel_results.append(channel_result)

        # Render the collected results into the HTML report.
        env = Environment(loader=FileSystemLoader(cfg['template_dir']))
        template = env.get_template('spike_sorting_results_template.html')
        template_output = template.render(subject=subject,
                                          recording_date=recording_date,
                                          channel_results=channel_results)

        out_filename = os.path.join(data_dir, 'spike_sorting_report.html')
        with open(out_filename, 'w') as fh:
            fh.write(template_output)

        # Ship the stylesheet next to the report.
        copyfile(os.path.join(cfg['template_dir'], 'style.css'),
                 os.path.join(data_dir, 'style.css'))