Example 1
def make_catalogue_figure():

    dataio = DataIO(dirname=dirname)
    catalogue = dataio.load_catalogue(chan_grp=0)

    clusters = catalogue['clusters']

    geometry = dataio.get_geometry(chan_grp=0)

    fig, ax = plt.subplots()
    ax.set_title('Catalogue has 4 templates')
    for i in range(clusters.size):
        color = clusters[i]['color']
        color = int32_to_rgba(color, mode='float')

        waveforms = catalogue['centers0'][i:i + 1]

        plot_waveforms_with_geometry(waveforms,
                                     channels,
                                     geometry,
                                     ax=ax,
                                     ratioY=3,
                                     deltaX=50,
                                     margin=50,
                                     color=color,
                                     linewidth=3,
                                     alpha=1,
                                     show_amplitude=True,
                                     ratio_mad=8)

    fig.savefig('../img/peeler_templates_for_animation.png')
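These snippets reference module-level names (dirname, channels) and imports that sit outside the excerpt. A minimal preamble sketch follows, assuming the DataIO / CatalogueConstructor / Peeler API comes from tridesclous, the library these examples exercise; the import style and the concrete values are assumptions, not taken from the source.

# Hypothetical preamble assumed by the snippets in this collection; the
# dirname and channels values are placeholders, not values from the source.
import time

import numpy as np
import matplotlib.pyplot as plt

from tridesclous import DataIO, CatalogueConstructor, Peeler

dirname = 'tridesclous_workdir'   # placeholder working directory
channels = [0, 1, 2, 3]           # placeholder channel ids for the figures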
Example 2
def apply_peeler():
    dataio = DataIO(dirname=dirname)
    catalogue = dataio.load_catalogue(chan_grp=0)
    peeler = Peeler(dataio)
    peeler.change_params(catalogue=catalogue, chunksize=1024)

    peeler.run(progressbar=True)
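After peeler.run() returns, the detected spikes live in the DataIO working directory. A sanity-check sketch that could be appended inside apply_peeler(), mirroring the loop in Example 17:

    # Count the spikes the peeler wrote for each segment (assumes the
    # dataio object from apply_peeler() above).
    for seg_num in range(dataio.nb_segment):
        spikes = dataio.get_spikes(seg_num)
        print('seg_num', seg_num, 'nb_spikes', spikes.size)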
Example 3
def preprocess_array(array_idx, output_dir, total_duration):
    dataio = DataIO(dirname=output_dir, ch_grp=array_idx)
    fullchain_kargs = {
        'duration': total_duration,
        'preprocessor': {
            'highpass_freq': 250.,
            'lowpass_freq': 3000.,
            'smooth_size': 0,
            'common_ref_removal': True,
            'chunksize': 32768,
            'lostfront_chunksize': 0,
            'signalpreprocessor_engine': 'numpy',
        }
    }
    cc = CatalogueConstructor(dataio=dataio, chan_grp=array_idx)
    p = {}
    p.update(fullchain_kargs['preprocessor'])
    cc.set_preprocessor_params(**p)
    # TODO: offer noise estimation duration somewhere
    noise_duration = min(
        10., fullchain_kargs['duration'],
        dataio.get_segment_length(seg_num=0) / dataio.sample_rate * .99)
    # ~ print('noise_duration', noise_duration)
    t1 = time.perf_counter()
    cc.estimate_signals_noise(seg_num=0, duration=noise_duration)
    t2 = time.perf_counter()
    print('estimate_signals_noise', t2 - t1)
    t1 = time.perf_counter()
    cc.run_signalprocessor(duration=fullchain_kargs['duration'],
                           detect_peak=False)
    t2 = time.perf_counter()
    print('run_signalprocessor', t2 - t1)
Example 4
def test_good_events():
    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 2)
    
    #~ peak_pos = peak_pos[:100]
    #~ print(peak_pos)
    
    peak_index = sigs.index[peak_pos]
    waveforms = extract_peak_waveforms(sigs, peak_pos, peak_index, -30, 50)
    keep = good_events(waveforms, upper_thr=5., lower_thr=-5.)
    #~ print(keep)
    goods_wf = waveforms[keep]
    bads_wf = waveforms[~keep]


    fig, ax = pyplot.subplots()
    #~ goods_wf.transpose().plot(ax =ax, color = 'g', lw = .3)
    bads_wf.transpose().plot(ax =ax,color = 'r', lw = .3)
    
    
    med = waveforms.median(axis=0)
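    # 1.4826 is 1/Phi^-1(3/4): it rescales the median absolute deviation so
    # that, for Gaussian noise, mad estimates the standard deviation. The
    # +/- 5 * mad band below then matches the upper_thr/lower_thr of 5
    # passed to good_events above.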
    mad = np.median(np.abs(waveforms-med),axis=0)*1.4826
    limit1 = med+5*mad
    limit2 = med-5*mad
    
    med.plot(ax = ax, color = 'm', lw = 2)
    limit1.plot(ax = ax, color = 'm')
    limit2.plot(ax = ax, color = 'm')
Example 5
def test_clustering():
    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    
    #peak
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 2)
    
    #waveforms
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold = 1.1)
    short_wf = waveformextractor.get_ajusted_waveforms()
    print(short_wf.shape)
    
    #clustering
    clustering = Clustering(short_wf)
    
    #PCA
    features = clustering.project(method = 'pca', n_components = 5)
    
    clustering.plot_projection(plot_density = True)
    
    #Kmean
    labels = clustering.find_clusters(7)
    clustering.plot_projection(plot_density = True)
    
    # make catalogue
    catalogue = clustering.construct_catalogue()
    clustering.plot_derivatives()
    clustering.plot_catalogue()


    clustering.merge_cluster(1,2)
    clustering.split_cluster(1, 2)
Example 6
def test_rectify_signals():
    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    rectified_sigs = rectify_signals(normalize_signals(sigs), threshold=-4)
    
    fig, ax = pyplot.subplots()
    rectified_sigs[3.14:3.22].plot(ax=ax)
    ax.set_ylim(-20, 10)
Example 7
def open_PeelerWindow(dirname, chan_grp):
    dataio = DataIO(dirname=dirname)
    initial_catalogue = dataio.load_catalogue(chan_grp=chan_grp)

    app = pg.mkQApp()
    win = PeelerWindow(dataio=dataio, catalogue=initial_catalogue)
    win.show()
    app.exec_()
Example 8
def test_peakdetector():
    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    
    peakdetector = PeakDetector(sigs, seg_num=0)
    peakdetector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 2)
    print(peakdetector.peak_pos.size)
    peakdetector.detect_peaks(threshold=-5, peak_sign = '-', n_span = 5)
    print(peakdetector.peak_pos.size)
Example 9
def test_DataIO_probes():
    # initialize dataio
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
        
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)


    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames,  **params)
    
    probe_filename = 'A4x8-5mm-100-400-413-A32.prb'
    dataio.download_probe(probe_filename)
    dataio.download_probe('A4x8-5mm-100-400-413-A32')
    
    #~ print(dataio.channel_groups)
    #~ print(dataio.channels)
    #~ print(dataio.info['probe_filename'])
    
    assert dataio.nb_channel(0) == 8
    assert probe_filename == dataio.info['probe_filename']
    
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)
Example 10
def test_peeler():
    dataio = DataIO(dirname = 'datatest')
    #~ dataio = DataIO(dirname = 'datatest_neo')
    
    sigs = dataio.get_signals(seg_num=0)
    
    #peak
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 5)
    
    #waveforms
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold = 1.1)
    #~ print(limit_left, limit_right)
    short_wf = waveformextractor.get_ajusted_waveforms()
    #~ print(short_wf.shape)
    
    #clustering
    clustering = Clustering(short_wf)
    features = clustering.project(method = 'pca', n_components = 4)
    clustering.find_clusters(8, order_clusters = True)
    catalogue = clustering.construct_catalogue()
    #~ clustering.plot_catalogue(sameax = False)
    #~ clustering.plot_catalogue(sameax = True)
    
    #~ clustering.merge_cluster(1, 2)
    catalogue = clustering.construct_catalogue()
    clustering.plot_catalogue(sameax = False)
    #~ clustering.plot_catalogue(sameax = True)
    
    
    #peeler
    signals = peakdetector.normed_sigs
    peeler = Peeler(signals, catalogue,  limit_left, limit_right,
                            threshold=-5., peak_sign = '-', n_span = 5)
    
    prediction0, residuals0 = peeler.peel()
    prediction1, residuals1 = peeler.peel()
    
    spiketrains = peeler.get_spiketrains()
    print(spiketrains)
    
    fig, axs = pyplot.subplots(nrows = 6, sharex = True)#, sharey = True)
    axs[0].plot(signals)
    axs[1].plot(prediction0) 
    axs[2].plot(residuals0)
    axs[3].plot(prediction1)
    axs[4].plot(residuals1)
    
    for i in range(5):
        axs[i].set_ylim(-25, 10)
    
    peeler.plot_spiketrains(ax = axs[5])
Example 11
def test_DataIO():
    
    
    # initialize dataio
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
        
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)


    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames,  **params)
    #~ dataio.set_channels(range(4))
    dataio.set_manual_channel_group(range(14))
    
    
    for seg_num in range(dataio.nb_segment):
        for i_stop, sigs_chunk in dataio.iter_over_chunk(seg_num=seg_num, chunksize=1024):
            assert sigs_chunk.shape[0] == 1024
            assert sigs_chunk.shape[1] == 14
            #~ print(seg_num, i_stop, sigs_chunk.shape)
    
    
    #reopen existing
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)
    
    #~ exit()
    
    for seg_num in range(dataio.nb_segment):
        #~ print('seg_num', seg_num)
        for i_stop, sigs_chunk in dataio.iter_over_chunk(seg_num=seg_num, chunksize=1024):
            assert sigs_chunk.shape[0] == 1024
            assert sigs_chunk.shape[1] == 14
Example 12
def test_extract_noise_waveforms():

    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 2)
    
    
    waveforms = extract_noise_waveforms(sigs, peak_pos, -30,50, size = 500)
    print(waveforms.shape)
    
    fig, ax = pyplot.subplots()
    waveforms.median(axis=0).plot(ax =ax)
Example 13
def run_peeler(dirname, chan_grp):
    dataio = DataIO(dirname=dirname, ch_grp=chan_grp)
    initial_catalogue = dataio.load_catalogue(chan_grp=chan_grp)

    peeler = Peeler(dataio)
    peeler.change_params(catalogue=initial_catalogue,
                         chunksize=32768,
                         use_sparse_template=False,
                         sparse_threshold_mad=1.5,
                         use_opencl_with_sparse=False)

    t1 = time.perf_counter()
    peeler.run()
    t2 = time.perf_counter()
    print('peeler.run', t2 - t1)
Example 14
def test_filter():
    dataio = DataIO(dirname = 'datatest_neo')
    sigs = dataio.get_signals(seg_num=0, signal_type = 'unfiltered')
    
    sig_filter = SignalFilter(sigs, highpass_freq=300.)
    filtered_sigs = sig_filter.get_filtered_data()

    sig_filter2 = SignalFilter(sigs, highpass_freq=300., box_smooth=3)
    filtered_sigs2 = sig_filter2.get_filtered_data()
    
    
    fig, axs = pyplot.subplots(nrows=3, sharex = True)
    sigs[0:.5].plot(ax=axs[0])
    filtered_sigs[0:.5].plot(ax=axs[1])
    filtered_sigs2[0:.5].plot(ax=axs[2])
Example 15
def run_merge_clusters(subject, recording_date, ch_grp, threshold):
    data_dir = os.path.join(cfg['intan_data_dir'], subject, recording_date)
    print(data_dir)
    if os.path.exists(data_dir):
        # Compute total duration (want to use all data for clustering)
        data_file_names = []
        for x in os.listdir(data_dir):
            if os.path.splitext(x)[1] == '.raw':
                data_file_names.append(os.path.join(data_dir, x))

        data_file_times = []
        for evt_file in data_file_names:
            fname = os.path.split(evt_file)[-1]
            fparts = fname.split('_')
            filedate = datetime.strptime(
                '%s %s' % (fparts[4], fparts[5].split('.')[0]),
                '%y%m%d %H%M%S')
            data_file_times.append(filedate)
        data_file_names = [
            x for _, x in sorted(zip(data_file_times, data_file_names))
        ]

        if os.path.exists(data_dir) and len(data_file_names) > 0:
            output_dir = os.path.join(cfg['single_unit_spike_sorting_dir'],
                                      subject, recording_date)

            ## Setup DataIO
            dataio = DataIO(dirname=output_dir)
            #dataio.set_data_source(type='RawData', filenames=data_file_names, dtype='float32', sample_rate=30000,total_channel=192)

            merge_clusters(dataio, ch_grp, output_dir, threshold)
Example 16
def preprocess_signals_and_peaks(dirname):
    dataio = DataIO(dirname=dirname)
    catalogueconstructor = CatalogueConstructor(dataio=dataio)

    catalogueconstructor.set_preprocessor_params(chunksize=1024,

                                                 # signal preprocessor
                                                 highpass_freq=250,
                                                 lowpass_freq=5000,
                                                 smooth_size=0,
                                                 common_ref_removal=False,
                                                 lostfront_chunksize=0,
                                                 signalpreprocessor_engine='opencl',

                                                 # peak detector
                                                 peakdetector_engine='opencl',
                                                 peak_sign='-',
                                                 relative_threshold=5,
                                                 peak_span=0.0002
                                                 )

    t1 = time.perf_counter()
    catalogueconstructor.estimate_signals_noise(seg_num=0, duration=10)
    t2 = time.perf_counter()
    print('estimate_signals_noise', t2 - t1)
    print(catalogueconstructor.signals_medians)
    print(catalogueconstructor.signals_mads)

    t1 = time.perf_counter()
    catalogueconstructor.run_signalprocessor(duration=300)
    t2 = time.perf_counter()
    print('run_signalprocessor', t2 - t1)

    print(catalogueconstructor)
Example 17
def run_peeler(dirname):
    dataio = DataIO(dirname=dirname)
    initial_catalogue = dataio.load_catalogue(chan_grp=0)

    peeler = Peeler(dataio)
    peeler.change_params(catalogue=initial_catalogue)

    t1 = time.perf_counter()
    peeler.run()
    t2 = time.perf_counter()
    print('peeler.run', t2 - t1)

    print()
    for seg_num in range(dataio.nb_segment):
        spikes = dataio.get_spikes(seg_num)
        print('seg_num', seg_num, 'nb_spikes', spikes.size)
Example 18
def test_dataio_with_neo():
    if os.path.exists('datatest_neo/data.h5'):
        os.remove('datatest_neo/data.h5')
    dataio = DataIO(dirname = 'datatest_neo')
    
    import neo
    import quantities as pq
    
    filenames = ['Tem06c06.IOT', 'Tem06c07.IOT', 'Tem06c08.IOT', ]
    for filename in filenames:
        blocks = neo.RawBinarySignalIO(filename).read(sampling_rate=10.*pq.kHz,
                        t_start=0.*pq.s, unit=pq.V, nbchannel=16, bytesoffset=0,
                        dtype='int16', rangemin=-10, rangemax=10)
        channel_indexes = np.arange(14)
        dataio.append_signals_from_neo(blocks, channel_indexes = channel_indexes, signal_type = 'unfiltered')
    print(dataio.summary(level=1))
Example 19
def test_extract_peak_waveforms():
    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 2)
    peak_index = sigs.index[peak_pos]
    
    waveforms = extract_peak_waveforms(sigs, peak_pos,peak_index,  -30,50)
    print(waveforms.shape)
    fig, ax = pyplot.subplots()
    waveforms.median(axis=0).plot(ax =ax)
    
    normed_waveforms = extract_peak_waveforms(peakdetector.normed_sigs, peak_pos, peak_index,  -15,50)
    fig, ax = pyplot.subplots()
    normed_waveforms.median(axis=0).plot(ax =ax)
Example 20
def open_cataloguewindow(dirname):
    dataio = DataIO(dirname=dirname)
    catalogueconstructor = CatalogueConstructor(dataio=dataio)

    app = pg.mkQApp()
    win = CatalogueWindow(catalogueconstructor)
    win.show()

    app.exec_()
Example 21
def test_waveform_extractor():
    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    
    peakdetector = PeakDetector(sigs)
    peakdetector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 5)
    
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold = 1.1)
    print(limit_left, limit_right)
    
    long_wf = waveformextractor.long_waveforms
    short_wf = waveformextractor.get_ajusted_waveforms()
    
    assert long_wf.shape[1]>short_wf.shape[1]
    
    waveformextractor.plot_good_limit()
Example 22
def test_peeler():
    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    
    #peak
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 5)
    
    #waveforms
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold = 1.1)
    #~ print(limit_left, limit_right)
    short_wf = waveformextractor.get_ajusted_waveforms(margin=2)
    #~ print(short_wf.shape)
    
    #clustering
    clustering = Clustering(short_wf)
    features = clustering.project(method = 'pca', n_components = 5)
    clustering.find_clusters(7)
    catalogue = clustering.construct_catalogue()
    
    clustering.plot_catalogue()
    
    #peeler
    signals = peakdetector.normed_sigs
    peeler = Peeler(signals, catalogue,  limit_left, limit_right,
                            threshold=-4, peak_sign = '-', n_span = 5)
    
    prediction0, residuals0 = peeler.peel()
    prediction1, residuals1 = peeler.peel()
    fig, axs = pyplot.subplots(nrows = 6, sharex = True)#, sharey = True)
    axs[0].plot(signals)
    axs[1].plot(prediction0) 
    axs[2].plot(residuals0)
    axs[3].plot(prediction1)
    axs[4].plot(residuals1)
    
    colors = sns.color_palette('husl', len(catalogue))
    spiketrains = peeler.get_spiketrains()
    i = 0
    for k , pos in spiketrains.items():
        axs[5].plot(pos, np.ones(pos.size)*k, ls = 'None', marker = '|',  markeredgecolor = colors[i], markersize = 10, markeredgewidth = 2)
        i += 1
    axs[5].set_ylim(0, len(catalogue))
Example 23
def test_dataio():
    if os.path.exists('datatest/data.h5'):
        os.remove('datatest/data.h5')
    dataio = DataIO(dirname = 'datatest')
    #~ print(data)
    #data from locust
    sigs_by_trials, sampling_rate, ch_names = download_locust(trial_names = ['trial_01', 'trial_02', 'trial_03'])
    
    
    for seg_num in range(3):
        sigs = sigs_by_trials[seg_num]
        dataio.append_signals_from_numpy(sigs, seg_num = seg_num,t_start = 0.+5*seg_num, sampling_rate =  sampling_rate,
                    signal_type = 'filtered', channels = ch_names)
    
    #~ print(data)
    #~ print(data.segments)
    #~ print(data.store)
    print(dataio.summary(level=0))
    print(dataio.summary(level=1))
Example 24
def plot_interpolation():
    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    
    #peak
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 5)
    
    #waveforms
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold = 1.1)
    #~ print(limit_left, limit_right)
    short_wf = waveformextractor.get_ajusted_waveforms(margin=2)
    #~ print(short_wf.shape)
    
    #clustering
    clustering = Clustering(short_wf)
    features = clustering.project(method = 'pca', n_components = 5)
    clustering.find_clusters(7)
    catalogue = clustering.construct_catalogue()
    
    k = list(catalogue.keys())[1]
    w0 = catalogue[k]['center']
    w1 = catalogue[k]['centerD']
    w2 = catalogue[k]['centerDD']
    
    fig, ax = pyplot.subplots()
    t = np.arange(w0.size)
    
    colors = sns.color_palette('husl', 12)
    
    preds = []
    jitters = np.arange(-.5, .5, .1)
    for i, jitter in enumerate(jitters):
        pred = w0 + jitter*w1 + jitter**2/2.*w2
        preds.append(pred)
        ax.plot(t+jitter, pred, marker = 'o', label = str(jitter), color = colors[i], linestyle = 'None')
    ax.plot(t, w0, marker = '*', markersize = 4, label = 'w0', lw = 1, color = 'k')     

    preds = np.array(preds)
    interpolated = preds.transpose().flatten()
    t2 = np.arange(interpolated.size)/preds.shape[0] + jitters[0]
    ax.plot(t2, interpolated, label = 'interp', lw = 1, color = 'm')
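The jitter loop above is a second-order Taylor expansion of the template around its sampled positions: with w0 the template and w1, w2 its first and second derivatives (the centerD and centerDD catalogue entries), the predicted waveform at a sub-sample shift delta is

    w(t + delta) ~= w0(t) + delta*w1(t) + (delta**2/2)*w2(t)

which is exactly the pred line above, evaluated for jitters between -0.5 and +0.5 samples.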
Example 25
def test_dataio_catalogue():
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
    
    dataio = DataIO(dirname='test_DataIO')
    
    catalogue = {}
    catalogue['chan_grp'] = 0
    catalogue['centers0'] = np.ones((300, 12, 50))
    
    catalogue['n_left'] = -15
    catalogue['params_signalpreprocessor'] = {'highpass_freq' : 300.}
    
    dataio.save_catalogue(catalogue, name='test')
    
    c2 = dataio.load_catalogue(name='test', chan_grp=0)
    print(c2)
    assert c2['n_left'] == -15
    assert np.all(c2['centers0']==1)
    assert c2['params_signalpreprocessor']['highpass_freq'] == 300.
Example 26
def test_detect_peak_method_span():
    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    normed_sigs = normalize_signals(sigs)
    rectified_sigs = rectify_signals(normed_sigs, threshold=-4)
    peaks_pos = detect_peak_method_span(rectified_sigs, peak_sign='-', n_span=5)
    peaks_index = sigs.index[peaks_pos]
    
    fig, ax = pyplot.subplots()
    chunk = rectified_sigs[3.14:3.22]
    chunk.plot(ax = ax)
    peaks_value = rectified_sigs.loc[peaks_index]
    peaks_value[3.14:3.22].plot(marker = 'o', linestyle = 'None', ax = ax, color = 'k')
    ax.set_ylim(-20, 10)
    
    fig, ax = pyplot.subplots()
    chunk = normed_sigs[3.14:3.22]
    chunk.plot(ax = ax)
    peaks_value = normed_sigs.loc[peaks_index]
    peaks_value[3.14:3.22].plot(marker = 'o', linestyle = 'None', ax = ax, color = 'k')
    ax.set_ylim(-20, 10)
Example 27
def test_find_good_limits():
    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    
    
    
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 5)
    peak_index = sigs.index[peak_pos]
    
    normed_waveforms = extract_peak_waveforms(peakdetector.normed_sigs, peak_pos, peak_index, -25,50)
    
    normed_med = normed_waveforms.median(axis=0)
    normed_mad = np.median(np.abs(normed_waveforms-normed_med),axis=0)*1.4826
    normed_mad = normed_mad.reshape(4,-1)
    
    fig, ax = pyplot.subplots()
    ax.plot(normed_mad.transpose())
    ax.axhline(1.1)
    
    l1, l2 = find_good_limits(normed_mad)
    print(l1,l2)
    ax.axvline(l1)
    ax.axvline(l2)
Example 28
def make_pca_collision_figure():

    dataio = DataIO(dirname=dirname)
    cc = CatalogueConstructor(dataio=dataio)

    clusters = cc.clusters
    #~ plot_features_scatter_2d(cc, labels=None, nb_max=500)

    #~ plot_features_scatter_2d

    fig, ax = plt.subplots()
    ax.set_title('Collision problem')
    ax.set_aspect('equal')
    features = cc.some_features

    labels = cc.all_peaks[cc.some_peaks_index]['cluster_label']

    for k in [0, 1, 2, 3]:
        color = clusters[clusters['cluster_label'] == k]['color'][0]
        color = int32_to_rgba(color, mode='float')

        keep = labels == k
        feat = features[keep]

        print(np.unique(labels))

        ax.plot(feat[:, 0],
                feat[:, 1],
                ls='None',
                marker='o',
                color=color,
                markersize=3,
                alpha=.5)

    ax.set_xlim(-40, 40)
    ax.set_ylim(-40, 40)

    ax.set_xlabel('pca0')
    ax.set_ylabel('pca1')

    ax.annotate('Collision',
                xy=(17.6, -16.4),
                xytext=(30, -30),
                arrowprops=dict(facecolor='black', shrink=0.05))

    #~

    fig.savefig('../img/collision_proble_pca.png')
Example 29
def clean_and_save_catalogue(dirname):
    dataio = DataIO(dirname=dirname)
    catalogueconstructor = CatalogueConstructor(dataio=dataio)

    catalogueconstructor.trash_small_cluster(n=5)

    # order cluster by waveforms rms
    catalogueconstructor.order_clusters(by='waveforms_rms')

    # put label 0 to trash
    mask = catalogueconstructor.all_peaks['cluster_label'] == 0
    catalogueconstructor.all_peaks['cluster_label'][mask] = -1
    catalogueconstructor.on_new_cluster()

    # save the catalogue
    catalogueconstructor.make_catalogue_for_peeler()
Example 30
def extract_waveforms_pca_cluster(dirname):
    dataio = DataIO(dirname=dirname)
    catalogueconstructor = CatalogueConstructor(dataio=dataio)
    print(catalogueconstructor)

    t1 = time.perf_counter()
    # ~ catalogueconstructor.extract_some_waveforms(n_left=-35, n_right=150,  nb_max=10000, align_waveform=True, subsample_ratio=20)
    catalogueconstructor.extract_some_waveforms(n_left=-20, n_right=30, nb_max=20000, align_waveform=False)
    t2 = time.perf_counter()
    print('extract_some_waveforms', t2 - t1)
    # ~ print(catalogueconstructor.some_waveforms.shape)
    print(catalogueconstructor)

    # ~ t1 = time.perf_counter()
    # ~ n_left, n_right = catalogueconstructor.find_good_limits(mad_threshold = 1.1,)
    # ~ t2 = time.perf_counter()
    # ~ print('n_left', n_left, 'n_right', n_right)
    # ~ print(catalogueconstructor.some_waveforms.shape)
    print(catalogueconstructor)

    # ~ print(catalogueconstructor.all_peaks)
    # ~ exit()

    t1 = time.perf_counter()
    catalogueconstructor.clean_waveforms(alien_value_threshold=100.)
    t2 = time.perf_counter()
    print('clean_waveforms', t2 - t1)

    # extract_some_noise
    t1 = time.perf_counter()
    catalogueconstructor.extract_some_noise(nb_snippet=300)
    t2 = time.perf_counter()
    print('extract_some_noise', t2 - t1)

    t1 = time.perf_counter()
    catalogueconstructor.project(method='pca_by_channel', n_components_by_channel=3)
    # ~ catalogueconstructor.project(method='tsne', n_components=2, perplexity=40., init='pca')
    t2 = time.perf_counter()
    print('project', t2 - t1)
    print(catalogueconstructor)

    t1 = time.perf_counter()
    catalogueconstructor.find_clusters(method='gmm', n_clusters=3*192)
    t2 = time.perf_counter()
    print('find_clusters', t2 - t1)
    print(catalogueconstructor)
Example 31
def make_catalogue():
    if os.path.exists(dirname):
        shutil.rmtree(dirname)

    dataio = DataIO(dirname=dirname)
    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames, **params)
    dataio.add_one_channel_group(channels=channels)

    cc = CatalogueConstructor(dataio=dataio)

    params = {
        'duration': 300.,
        'preprocessor': {
            'highpass_freq': 300.,
            'chunksize': 1024,
            'lostfront_chunksize': 100,
        },
        'peak_detector': {
            'peak_sign': '-',
            'relative_threshold': 7.,
            'peak_span': 0.0005,
            #~ 'peak_span' : 0.000,
        },
        'extract_waveforms': {
            'n_left': -25,
            'n_right': 40,
            'nb_max': 10000,
        },
        'clean_waveforms': {
            'alien_value_threshold': 60.,
        },
        'noise_snippet': {
            'nb_snippet': 300,
        },
        'feature_method': 'global_pca',
        'feature_kargs': {
            'n_components': 20
        },
        'cluster_method': 'kmeans',
        'cluster_kargs': {
            'n_clusters': 5
        },
        'clean_cluster': False,
        'clean_cluster_kargs': {},
    }

    apply_all_catalogue_steps(cc, params, verbose=True)

    cc.order_clusters(by='waveforms_rms')
    cc.move_cluster_to_trash(4)
    cc.make_catalogue_for_peeler()
Example 32
def preprocess_data(subject, recording_date, data_files):
    output_dir = os.path.join(cfg['single_unit_spike_sorting_dir'], subject,
                              recording_date, 'preprocess')
    if os.path.exists(output_dir):
        # remove if it already exists
        shutil.rmtree(output_dir)

    ## Setup DataIO
    dataio = DataIO(dirname=output_dir)
    dataio.set_data_source(type='Intan',
                           filenames=[x['fname'] for x in data_files])

    # Setup channel groups
    arrays_recorded = []
    grp_idx = 0
    for array_idx in range(len(cfg['arrays'])):
        # first channel name expected for each array
        first_chans = ['A-000', 'A-032', 'B-000', 'B-032', 'C-000', 'C-032']
        first_chan = first_chans[array_idx] if array_idx < len(first_chans) else ''
        found = False
        for i in range(len(dataio.datasource.sig_channels)):
            if dataio.datasource.sig_channels[i][0] == first_chan:
                found = True
                break

        chan_range = []
        if found:
            chan_range = range(grp_idx * cfg['n_channels_per_array'],
                               (grp_idx + 1) * cfg['n_channels_per_array'])
            grp_idx = grp_idx + 1
            arrays_recorded.append(array_idx)
        dataio.add_one_channel_group(channels=chan_range, chan_grp=array_idx)

    print(dataio)

    total_duration = np.sum([x['duration'] for x in data_files])
    for array_idx in arrays_recorded:
        print(array_idx)
        preprocess_array(array_idx, output_dir, total_duration)
Example 33
def getSortedTimes(dirName, chanGroup):

    dataio = DataIO(dirname=dirName)
    dataio.load_catalogue(chan_grp=chanGroup)
    catalogueconstructor = CatalogueConstructor(dataio=dataio)

    sample_rate = dataio.sample_rate  # used to convert spike indexes to seconds

    unitTimes = np.empty(dataio.nb_segment, dtype=object)

    for j in range(dataio.nb_segment):
        
        idd = {}
        times = {}

        try:
            # List of all cluster labels
            cluster_ids = np.array([i for i in catalogueconstructor.cluster_labels])
            # List of all detected peaks by cluster ID     
            clusters = np.array([i[1] for i in dataio.get_spikes(j)])

            spike_times = np.array([i[0] for i in dataio.get_spikes(j)])

        except Exception:
            cluster_ids  = np.array([])
            clusters = np.array([])
            spike_times = np.array([])

        for i in cluster_ids:
            idd[i] = np.argwhere(clusters == i)

        for i in cluster_ids:
            times[i] = spike_times[idd[i]]/sample_rate  

        mx = np.max([times[i].size for i in times.keys()])

        for i in times.keys():
            times[i].resize(mx + 1, 1)

        timesArray = np.array([times[i] for i in times.keys()])

        timesArray = np.roll(timesArray, 1)
        timesArray[:, 0, :] = np.array(list(times.keys())).reshape(timesArray.shape[0], 1)

        timesArray = np.transpose(timesArray)

        unitTimes[j] = timesArray[0]
    
    return unitTimes
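A hypothetical call, to make the return layout concrete: getSortedTimes yields one array per segment whose first row holds the cluster labels, with each column below it carrying the zero-padded spike times (in seconds) for that cluster. The directory name and channel group here are placeholders:

unit_times = getSortedTimes('output_dir', 0)   # placeholder arguments
seg0 = unit_times[0]                # segment 0, shape (max_n_times + 1, n_clusters)
cluster_ids = seg0[0, :]            # first row: cluster labels
first_cluster_times = seg0[1:, 0]   # padded spike times of the first cluster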
Example 34
def export_spikes(dirname, array_idx, chan_grp):
    print('Exporting ch %d' % chan_grp)
    data = {
        'array': [],
        'electrode': [],
        'cell': [],
        'segment': [],
        'time': []
    }
    array = cfg['arrays'][array_idx]

    dataio = DataIO(dirname=dirname, ch_grp=chan_grp)
    catalogue = dataio.load_catalogue(chan_grp=chan_grp)
    dataio._open_processed_data(ch_grp=chan_grp)

    clusters = catalogue['clusters']

    for seg_num in range(dataio.nb_segment):
        spikes = dataio.get_spikes(seg_num=seg_num, chan_grp=chan_grp)

        spike_labels = spikes['cluster_label'].copy()
        for l in clusters:
            mask = spike_labels == l['cluster_label']
            spike_labels[mask] = l['cell_label']
        spike_indexes = spikes['index']

        for (index, label) in zip(spike_indexes, spike_labels):
            if label >= 0:
                data['array'].append(array)
                data['electrode'].append(chan_grp)
                data['cell'].append(label)
                data['segment'].append(seg_num)
                data['time'].append(index)
        dataio.flush_processed_signals(seg_num=seg_num, chan_grp=chan_grp)
    df = pd.DataFrame(
        data, columns=['array', 'electrode', 'cell', 'segment', 'time'])
    df.to_csv(os.path.join(dirname, '%s_%d_spikes.csv' % (array, chan_grp)),
              index=False)
Example 35
def initialize_catalogueconstructor(dirname, filenames):
    # create a DataIO
    if os.path.exists(dirname):
        # remove if it already exists
        shutil.rmtree(dirname)
    dataio = DataIO(dirname=dirname)

    # The dataset contains 4 channels : we use them all
    #dataio.set_channel_groups({'channels':{'channels':[0, 1, 2, 3]}})

    # feed DataIO
    dataio.set_data_source(type='Intan', filenames=filenames, channel_indexes=list(range(192)))
    #dataio.set_probe_file('/home/bonaiuto/Projects/tool_learning/recordings/rhd2000/betta/default.prb')

    dataio.add_one_channel_group(channels=range(192), chan_grp=0)

    print(dataio)
Example 36
def generate_spike_sorting_report(subject, recording_date):

    data_dir = os.path.join(cfg['single_unit_spike_sorting_dir'], subject,
                            recording_date)
    if os.path.exists(data_dir):
        channel_results = []

        for array_idx in range(len(cfg['arrays'])):
            array = cfg['arrays'][array_idx]
            print(array)

            array_data_dir = os.path.join(data_dir, 'array_%d' % array_idx)

            if os.path.exists(array_data_dir):
                export_path = os.path.join(array_data_dir, 'figures')
                if not os.path.exists(export_path):
                    os.makedirs(export_path)

                for chan_grp in range(cfg['n_channels_per_array']):
                    print(chan_grp)

                    dataio = DataIO(array_data_dir, ch_grp=chan_grp)
                    dataio.datasource.bit_to_microVolt = 0.195
                    catalogueconstructor = CatalogueConstructor(
                        dataio=dataio, chan_grp=chan_grp)
                    catalogueconstructor.refresh_colors()
                    catalogue = dataio.load_catalogue(chan_grp=chan_grp)

                    channel_result = {
                        'array': array,
                        'channel': chan_grp,
                        'init_waveforms': '',
                        'clean_waveforms': '',
                        'noise': '',
                        'init_clusters': '',
                        'merge_clusters': [],
                        'final_clusters': [],
                        'all_clusters': ''
                    }

                    clusters = catalogue['clusters']

                    cluster_labels = clusters['cluster_label']
                    cell_labels = clusters['cell_label']

                    channel_result['init_waveforms'] = os.path.join(
                        'array_%d' % array_idx, 'figures',
                        'chan_%d_init_waveforms.png' % chan_grp)
                    channel_result['clean_waveforms'] = os.path.join(
                        'array_%d' % array_idx, 'figures',
                        'chan_%d_clean_waveforms.png' % chan_grp)
                    channel_result['noise'] = os.path.join(
                        'array_%d' % array_idx, 'figures',
                        'chan_%d_noise.png' % chan_grp)
                    channel_result['init_clusters'] = os.path.join(
                        'array_%d' % array_idx, 'figures',
                        'chan_%d_init_clusters.png' % chan_grp)

                    merge_files = glob.glob(
                        os.path.join(export_path,
                                     'chan_%d_merge_*.png' % chan_grp))
                    for merge_file in merge_files:
                        [path, file] = os.path.split(merge_file)
                        channel_result['merge_clusters'].append(
                            os.path.join('array_%d' % array_idx, 'figures',
                                         file))

                    for cluster_label in cluster_labels:
                        fig = plot_cluster_summary(dataio, catalogue, chan_grp,
                                                   cluster_label)
                        fname = 'chan_%d_cluster_%d.png' % (chan_grp,
                                                            cluster_label)
                        fig.savefig(os.path.join(export_path, fname))
                        fig.clf()
                        plt.close()
                        channel_result['final_clusters'].append(
                            os.path.join('array_%d' % array_idx, 'figures',
                                         fname))

                    fig = plot_clusters_summary(dataio, catalogueconstructor,
                                                chan_grp)
                    fname = 'chan_%d_clusters.png' % chan_grp
                    fig.savefig(os.path.join(export_path, fname))
                    fig.clf()
                    plt.close()
                    channel_result['all_clusters'] = os.path.join(
                        'array_%d' % array_idx, 'figures', fname)

                    channel_results.append(channel_result)

        env = Environment(loader=FileSystemLoader(cfg['template_dir']))
        template = env.get_template('spike_sorting_results_template.html')
        template_output = template.render(subject=subject,
                                          recording_date=recording_date,
                                          channel_results=channel_results)

        out_filename = os.path.join(data_dir, 'spike_sorting_report.html')
        with open(out_filename, 'w') as fh:
            fh.write(template_output)

        copyfile(os.path.join(cfg['template_dir'], 'style.css'),
                 os.path.join(data_dir, 'style.css'))
Example 37
def test_derivative_signals():
    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    deriv_sigs = derivative_signals(sigs)
    deriv_sigs[3.14:3.22].plot()
Example 38
def test_normalize_signals():
    dataio = DataIO(dirname = 'datatest')
    sigs = dataio.get_signals(seg_num=0)
    normed_sigs = normalize_signals(sigs)
    normed_sigs[3.14:3.22].plot()
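normalize_signals itself is not shown in this collection. Since the detection thresholds throughout are expressed in MAD units (threshold=-4, threshold=-5), a plausible sketch is per-channel median centering and MAD scaling; this is an assumption about its behaviour, not the library's actual code:

import numpy as np

def normalize_signals_sketch(sigs):
    # Assumed behaviour: center each channel on its median and divide by its
    # MAD (scaled by 1.4826 so it estimates the noise standard deviation),
    # so that thresholds such as -4 are in units of noise sigma.
    med = np.median(sigs, axis=0)
    mad = np.median(np.abs(sigs - med), axis=0) * 1.4826
    return (sigs - med) / mad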
Example 39
def test_DataIO():
    
    
    # initialize dataio
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
        
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)


    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames, **params)
    
    #with geometry
    channels = list(range(14))
    channel_groups = {0:{'channels':range(14), 'geometry' : { c: [0, i] for i, c in enumerate(channels) }}}
    dataio.set_channel_groups(channel_groups)
    
    #with no geometry
    channel_groups = {0:{'channels':range(4)}}
    dataio.set_channel_groups(channel_groups)
    
    # add one group
    dataio.add_one_channel_group(channels=range(4,8), chan_grp=5)
    
    
    channel_groups = {0:{'channels':range(14)}}
    dataio.set_channel_groups(channel_groups)
    
    for seg_num in range(dataio.nb_segment):
        for i_stop, sigs_chunk in dataio.iter_over_chunk(seg_num=seg_num, chunksize=1024):
            assert sigs_chunk.shape[0] == 1024
            assert sigs_chunk.shape[1] == 14
            #~ print(seg_num, i_stop, sigs_chunk.shape)
    
    
    #reopen existing
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)
    
    #~ exit()
    
    for seg_num in range(dataio.nb_segment):
        #~ print('seg_num', seg_num)
        for i_stop, sigs_chunk in dataio.iter_over_chunk(seg_num=seg_num, chunksize=1024):
            assert sigs_chunk.shape[0] == 1024
            assert sigs_chunk.shape[1] == 14
Example 40
def test_DataIO_probes():
    # initialize dataio
    if os.path.exists('test_DataIO'):
        shutil.rmtree('test_DataIO')
        
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)


    localdir, filenames, params = download_dataset(name='olfactory_bulb')
    dataio.set_data_source(type='RawData', filenames=filenames,  **params)
    
    probe_filename = 'neuronexus/A4x8-5mm-100-400-413-A32.prb'
    dataio.download_probe(probe_filename)
    dataio.download_probe('neuronexus/A4x8-5mm-100-400-413-A32')
    
    #~ print(dataio.channel_groups)
    #~ print(dataio.channels)
    #~ print(dataio.info['probe_filename'])
    
    assert dataio.nb_channel(0) == 8
    assert probe_filename.split('/')[-1] == dataio.info['probe_filename']
    
    dataio = DataIO(dirname='test_DataIO')
    print(dataio)
Example 41
def make_animation():
    """
    Good example between 1.272 and 1.302
    because of a collision
    """

    dataio = DataIO(dirname=dirname)
    catalogue = dataio.load_catalogue(chan_grp=0)

    clusters = catalogue['clusters']

    sr = dataio.sample_rate

    # also a good one at 11.356 - 11.366

    t1, t2 = 1.272, 1.295
    i1, i2 = int(t1 * sr), int(t2 * sr)

    spikes = dataio.get_spikes()
    spike_times = spikes['index'] / sr
    keep = (spike_times >= t1) & (spike_times <= t2)

    spikes = spikes[keep]
    print(spikes)

    sigs = dataio.get_signals_chunk(i_start=i1,
                                    i_stop=i2,
                                    signal_type='processed')
    sigs = sigs.copy()
    times = np.arange(sigs.shape[0]) / dataio.sample_rate

    def plot_spread_sigs(sigs, ax, ratioY=0.02, **kargs):
        # spread signals vertically, one offset per channel
        sigs2 = sigs * ratioY
        sigs2 += np.arange(0, len(channels))[np.newaxis, :]
        ax.plot(times, sigs2, **kargs)

        ax.set_ylim(-0.5, len(channels) - .5)
        ax.set_xticks([])
        ax.set_yticks([])

    residuals = sigs.copy()

    local_spikes = spikes.copy()
    local_spikes['index'] -= i1

    #~ fig, ax = plt.subplots()
    #~ plot_spread_sigs(sigs, ax, color='k')

    num_fig = 0

    fig_pred, ax_predictions = plt.subplots()
    ax_predictions.set_title('All detected templates from catalogue')

    fig, ax = plt.subplots()
    plot_spread_sigs(residuals, ax, color='k', lw=2)
    ax.set_title('Initial filtered signals with spikes')

    fig.savefig('../img/peeler_animation_sigs.png')

    fig.savefig('png/fig{}.png'.format(num_fig))
    num_fig += 1

    for i in range(local_spikes.size):
        label = local_spikes['cluster_label'][i]

        color = clusters[clusters['cluster_label'] == label]['color'][0]
        color = int32_to_rgba(color, mode='float')

        pred = make_prediction_signals(local_spikes[i:i + 1], 'float32',
                                       (i2 - i1, len(channels)), catalogue)

        fig, ax = plt.subplots()
        plot_spread_sigs(residuals, ax, color='k', lw=2)
        plot_spread_sigs(pred, ax, color=color, lw=1.5)
        ax.set_title('Detected spike label {}'.format(label))

        fig.savefig('png/fig{}.png'.format(num_fig))
        num_fig += 1

        residuals -= pred

        plot_spread_sigs(pred, ax_predictions, color=color, lw=1.5)

        fig, ax = plt.subplots()
        plot_spread_sigs(residuals, ax, color='k', lw=2)
        plot_spread_sigs(pred, ax, color=color, lw=1, ls='--')
        ax.set_title('New residual after subtraction')

        fig.savefig('png/fig{}.png'.format(num_fig))
        num_fig += 1

    fig_pred.savefig('png/fig{}.png'.format(num_fig))
    num_fig += 1
Example 42
def export_spikes(dirname):
    dataio = DataIO(dirname=dirname)
    dataio.export_spikes(dirname, formats='csv')
    dataio.export_spikes(dirname, formats='mat')
Example 43
def compute_array_catalogue(array_idx, preprocess_dir, subject, recording_date,
                            data_files, cluster_merge_threshold):
    # If data exists for this array
    if os.path.exists(
            os.path.join(preprocess_dir, 'channel_group_%d' % array_idx,
                         'catalogue_constructor')):
        output_dir = os.path.join(cfg['single_unit_spike_sorting_dir'],
                                  subject, recording_date,
                                  'array_%d' % array_idx)
        if os.path.exists(output_dir):
            # remove if it already exists
            shutil.rmtree(output_dir)
        # Compute total duration (want to use all data for clustering)
        data_file_names = []
        for seg in range(len(data_files)):
            data_file_names.append(
                os.path.join(preprocess_dir, 'channel_group_%d' % array_idx,
                             'segment_%d' % seg, 'processed_signals.raw'))

        dataio = DataIO(dirname=output_dir)
        dataio.set_data_source(type='RawData',
                               filenames=data_file_names,
                               dtype='float32',
                               sample_rate=cfg['intan_srate'],
                               total_channel=cfg['n_channels_per_array'])
        dataio.datasource.bit_to_microVolt = 0.195
        for ch_grp in range(cfg['n_channels_per_array']):
            dataio.add_one_channel_group(channels=[ch_grp], chan_grp=ch_grp)

        total_duration = np.sum([x['duration'] for x in data_files])

        figure_out_dir = os.path.join(output_dir, 'figures')
        os.mkdir(figure_out_dir)
        for ch_grp in range(cfg['n_channels_per_array']):
            print(ch_grp)
            cc = CatalogueConstructor(dataio=DataIO(dirname=output_dir,
                                                    ch_grp=ch_grp),
                                      chan_grp=ch_grp)

            fullchain_kargs = {
                'duration': total_duration,
                'preprocessor': {
                    'highpass_freq': None,
                    'lowpass_freq': None,
                    'smooth_size': 0,
                    'common_ref_removal': False,
                    'chunksize': 32768,
                    'lostfront_chunksize': 0,
                    'signalpreprocessor_engine': 'numpy',
                },
                'peak_detector': {
                    'peakdetector_engine': 'numpy',
                    'peak_sign': '-',
                    'relative_threshold': 2.,
                    'peak_span': 0.0002,
                },
                'noise_snippet': {
                    'nb_snippet': 300,
                },
                'extract_waveforms': {
                    'n_left': -20,
                    'n_right': 30,
                    'mode': 'all',
                    'nb_max': 2000000,
                    'align_waveform': False,
                },
                'clean_waveforms': {
                    'alien_value_threshold': 100.,
                },
            }
            feat_method = 'pca_by_channel'
            feat_kargs = {'n_components_by_channel': 5}
            clust_method = 'sawchaincut'
            clust_kargs = {
                'max_loop': 1000,
                'nb_min': 20,
                'break_nb_remain': 30,
                'kde_bandwith': 0.01,
                'auto_merge_threshold': 2.,
                'print_debug': False
                # 'max_loop': 1000,
                # 'nb_min': 20,
                # 'break_nb_remain': 30,
                # 'kde_bandwith': 0.01,
                # 'auto_merge_threshold': cluster_merge_threshold,
                # 'print_debug': False
            }

            p = {}
            p.update(fullchain_kargs['preprocessor'])
            p.update(fullchain_kargs['peak_detector'])
            cc.set_preprocessor_params(**p)

            noise_duration = min(
                10., fullchain_kargs['duration'],
                dataio.get_segment_length(seg_num=0) / dataio.sample_rate *
                .99)
            # ~ print('noise_duration', noise_duration)
            t1 = time.perf_counter()
            cc.estimate_signals_noise(seg_num=0, duration=noise_duration)
            t2 = time.perf_counter()
            print('estimate_signals_noise', t2 - t1)

            t1 = time.perf_counter()
            cc.run_signalprocessor(duration=fullchain_kargs['duration'])
            t2 = time.perf_counter()
            print('run_signalprocessor', t2 - t1)

            t1 = time.perf_counter()
            cc.extract_some_waveforms(**fullchain_kargs['extract_waveforms'])
            t2 = time.perf_counter()
            print('extract_some_waveforms', t2 - t1)

            fname = 'chan_%d_init_waveforms.png' % ch_grp
            fig = plot_waveforms(np.squeeze(cc.some_waveforms).T)
            fig.savefig(os.path.join(figure_out_dir, fname))
            fig.clf()
            plt.close()

            t1 = time.perf_counter()
            # ~ duration = d['duration'] if d['limit_duration'] else None
            # ~ d['clean_waveforms']
            cc.clean_waveforms(**fullchain_kargs['clean_waveforms'])
            t2 = time.perf_counter()
            print('clean_waveforms', t2 - t1)

            fname = 'chan_%d_clean_waveforms.png' % ch_grp
            fig = plot_waveforms(np.squeeze(cc.some_waveforms).T)
            fig.savefig(os.path.join(figure_out_dir, fname))
            fig.clf()
            plt.close()

            # ~ t1 = time.perf_counter()
            # ~ n_left, n_right = cc.find_good_limits(mad_threshold = 1.1,)
            # ~ t2 = time.perf_counter()
            # ~ print('find_good_limits', t2-t1)

            t1 = time.perf_counter()
            cc.extract_some_noise(**fullchain_kargs['noise_snippet'])
            t2 = time.perf_counter()
            print('extract_some_noise', t2 - t1)

            # Plot noise
            fname = 'chan_%d_noise.png' % ch_grp
            fig = plot_noise(cc)
            fig.savefig(os.path.join(figure_out_dir, fname))
            fig.clf()
            plt.close()

            t1 = time.perf_counter()
            cc.extract_some_features(method=feat_method, **feat_kargs)
            t2 = time.perf_counter()
            print('project', t2 - t1)

            t1 = time.perf_counter()
            cc.find_clusters(method=clust_method, **clust_kargs)
            t2 = time.perf_counter()
            print('find_clusters', t2 - t1)

            # Remove empty clusters
            cc.trash_small_cluster(n=0)

            if cc.centroids_median is None:
                cc.compute_all_centroid()

            # order cluster by waveforms rms
            cc.order_clusters(by='waveforms_rms')

            fname = 'chan_%d_init_clusters.png' % ch_grp
            cluster_labels = cc.clusters['cluster_label']
            fig = plot_cluster_waveforms(cc, cluster_labels)
            fig.savefig(os.path.join(figure_out_dir, fname))
            fig.clf()
            plt.close()

            # save the catalogue
            cc.make_catalogue_for_peeler()

            gc.collect()
Example 44
def run_compare_catalogues(subject, date, similarity_threshold=0.7):
    bin_min = 0
    bin_max = 100
    bin_size = 1.
    bins = np.arange(bin_min, bin_max, bin_size)

    if os.path.exists(
            os.path.join(cfg['single_unit_spike_sorting_dir'], subject, date)):
        # Get previous days to compare to
        sorted_dates, sorted_files = read_previous_sorts(subject)

        new_date = datetime.strptime(date, '%d.%m.%y')

        # Results for report
        channel_results = []

        for array_idx in range(len(cfg['arrays'])):
            new_output_dir = os.path.join(cfg['single_unit_spike_sorting_dir'],
                                          subject, date,
                                          'array_%d' % array_idx)

            if os.path.exists(new_output_dir):
                # Create directory for plots
                plot_output_dir = os.path.join(new_output_dir, 'figures',
                                               'catalogue_comparison')
                if not os.path.exists(plot_output_dir):
                    os.mkdir(plot_output_dir)

                for ch_grp in range(cfg['n_channels_per_array']):
                    channel_result = {
                        'array': cfg['arrays'][array_idx],
                        'channel': ch_grp,
                        'merged': [],
                        'unmerged': [],
                        'final': ''
                    }

                    # load catalogue for this channel
                    new_dataio = DataIO(dirname=new_output_dir, ch_grp=ch_grp)
                    catalogueconstructor = CatalogueConstructor(
                        dataio=new_dataio, chan_grp=ch_grp)

                    # Waveform time range
                    time_range = range(
                        catalogueconstructor.info['waveform_extractor_params']
                        ['n_left'],
                        catalogueconstructor.info['waveform_extractor_params']
                        ['n_right'])

                    # refresh
                    if catalogueconstructor.centroids_median is None:
                        catalogueconstructor.compute_all_centroid()
                    catalogueconstructor.refresh_colors()

                    # cell labels and cluster waveforms for this day
                    nn_idx = np.where(
                        catalogueconstructor.clusters['cell_label'] > -1)[0]

                    # If there are any clusters for this day
                    if len(nn_idx) > 0:

                        new_cluster_labels, new_cell_labels, new_wfs, new_wfs_stds, new_isis = get_cluster_info(
                            catalogueconstructor, new_dataio, ch_grp, bins,
                            nn_idx)

                        # Load cell labels and waveforms for all previous days
                        all_old_cluster_labels = []
                        all_old_cell_labels = []
                        all_old_wfs = np.zeros((0, 50))
                        all_old_wfs_stds = np.zeros((0, 50))
                        all_old_isis = np.zeros((0, len(bins) - 1))
                        recent_sorted_files = copy.copy(sorted_files)
                        recent_sorted_dates = copy.copy(sorted_dates)
                        if len(recent_sorted_files) > 20:
                            recent_sorted_files = recent_sorted_files[-20:]
                            recent_sorted_dates = recent_sorted_dates[-20:]

                        for old_date, old_file in zip(recent_sorted_dates,
                                                      recent_sorted_files):
                            if old_date < new_date:
                                print('loading %s' % old_date)
                                old_output_dir = os.path.join(
                                    cfg['single_unit_spike_sorting_dir'],
                                    subject, old_file, 'array_%d' % array_idx)
                                if os.path.exists(old_output_dir):
                                    old_dataio = DataIO(
                                        dirname=old_output_dir,
                                        ch_grp=ch_grp,
                                        reload_data_source=False)

                                    old_catalogueconstructor = CatalogueConstructor(
                                        dataio=old_dataio,
                                        chan_grp=ch_grp,
                                        load_persistent_arrays=False)
                                    old_catalogueconstructor.arrays.load_if_exists(
                                        'clusters')
                                    old_catalogueconstructor.arrays.load_if_exists(
                                        'centroids_median')
                                    old_catalogueconstructor.arrays.load_if_exists(
                                        'centroids_std')

                                    if old_catalogueconstructor.centroids_median is None:
                                        old_catalogueconstructor.compute_all_centroid(
                                        )

                                    old_cell_labels = old_catalogueconstructor.clusters[
                                        'cell_label']
                                    # Keep clusters assigned to a cell
                                    # (label > -1) whose cell label has not
                                    # already been collected from a more
                                    # recent day
                                    to_include = np.where(
                                        ~np.isin(old_cell_labels,
                                                 all_old_cell_labels)
                                        & (old_cell_labels > -1))[0]

                                    if len(to_include):
                                        old_cluster_labels, old_cell_labels, old_wfs, old_wfs_stds, old_isis = get_cluster_info(
                                            old_catalogueconstructor,
                                            old_dataio, ch_grp, bins,
                                            to_include)

                                        # np.concatenate copies its inputs, so
                                        # the accumulators never alias the
                                        # per-day arrays
                                        all_old_wfs = np.concatenate(
                                            (all_old_wfs, old_wfs))
                                        all_old_wfs_stds = np.concatenate(
                                            (all_old_wfs_stds, old_wfs_stds))
                                        all_old_cell_labels.extend(
                                            old_cell_labels)
                                        all_old_cluster_labels.extend(
                                            old_cluster_labels)
                                        all_old_isis = np.concatenate(
                                            (all_old_isis, old_isis))
                                    #old_dataio.arrays.detach_array('clusters', mmap_close=True)
                                    #old_dataio.arrays.detach_array('centroids_median', mmap_close=True)
                                    #old_dataio.arrays.detach_array('centroids_std', mmap_close=True)

                        if len(all_old_cell_labels):
                            # Pairwise similarity between all templates (new
                            # first, then old); keep the block that compares
                            # new clusters (rows) with old clusters (columns)
                            wfs = np.concatenate((new_wfs, all_old_wfs))
                            cluster_similarity = metrics.cosine_similarity_with_max(
                                wfs)
                            new_old_cluster_similarity = cluster_similarity[
                                0:new_wfs.shape[0], new_wfs.shape[0]:]

                            # Plot cluster similarity
                            fname = 'chan_%d_similarity.png' % ch_grp
                            plot_new_old_cluster_similarity(
                                all_old_cluster_labels, all_old_cell_labels,
                                new_cluster_labels, new_cell_labels,
                                new_old_cluster_similarity, plot_output_dir,
                                fname)
                            channel_result['similarity'] = os.path.join(
                                'array_%d' % array_idx, 'figures',
                                'catalogue_comparison', fname)

                            # Go through each cluster in current day
                            for new_cluster_idx in range(new_wfs.shape[0]):
                                # Find most similar cluster from previous days
                                most_similar = np.argmax(
                                    new_old_cluster_similarity[
                                        new_cluster_idx, :])

                                # Merge if similarity greater than threshold
                                similarity = new_old_cluster_similarity[
                                    new_cluster_idx, most_similar]
                                if similarity >= similarity_threshold:
                                    print(
                                        'relabeling unit %d-%d as unit %d-%d' %
                                        (ch_grp,
                                         new_cell_labels[new_cluster_idx],
                                         ch_grp,
                                         all_old_cell_labels[most_similar]))
                                    fname = 'chan_%d_merge_%d-%d.png' % (
                                        ch_grp,
                                        new_cell_labels[new_cluster_idx],
                                        all_old_cell_labels[most_similar])

                                    plot_cluster_merge(
                                        all_old_cell_labels[most_similar],
                                        all_old_cluster_labels[most_similar],
                                        all_old_isis[most_similar, :],
                                        all_old_wfs[most_similar, :],
                                        all_old_wfs_stds[most_similar, :],
                                        new_cell_labels[new_cluster_idx],
                                        new_cluster_labels[new_cluster_idx],
                                        new_isis[new_cluster_idx, :],
                                        new_wfs[new_cluster_idx, :],
                                        new_wfs_stds[new_cluster_idx, :],
                                        similarity, time_range, bins,
                                        plot_output_dir, fname)
                                    channel_result['merged'].append(
                                        os.path.join('array_%d' % array_idx,
                                                     'figures',
                                                     'catalogue_comparison',
                                                     fname))

                                    new_cell_labels[
                                        new_cluster_idx] = all_old_cell_labels[
                                            most_similar]
                                # Otherwise, register a brand-new cell label
                                else:
                                    new_label = np.max(all_old_cell_labels) + 1
                                    print('adding new unit %d-%d' %
                                          (ch_grp, new_label))
                                    # Record the label so later clusters on
                                    # this day cannot reuse it
                                    all_old_cell_labels.append(new_label)
                                    new_cell_labels[
                                        new_cluster_idx] = new_label

                                    fname = 'chan_%d_nonmerge_%d.png' % (
                                        ch_grp,
                                        new_cell_labels[new_cluster_idx])

                                    plot_cluster_new(
                                        all_old_cluster_labels,
                                        all_old_cell_labels, all_old_isis,
                                        all_old_wfs, all_old_wfs_stds,
                                        new_cluster_labels[new_cluster_idx],
                                        new_cell_labels[new_cluster_idx],
                                        new_isis[new_cluster_idx, :],
                                        new_wfs[new_cluster_idx, :],
                                        new_wfs_stds[new_cluster_idx, :],
                                        new_old_cluster_similarity[
                                            new_cluster_idx:new_cluster_idx +
                                            1, :], time_range, bins,
                                        plot_output_dir, fname)
                                    channel_result['unmerged'].append(
                                        os.path.join('array_%d' % array_idx,
                                                     'figures',
                                                     'catalogue_comparison',
                                                     fname))

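                            # Write the relabeled cell labels back to this
                            # day's clusters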
                            catalogueconstructor.clusters['cell_label'][
                                nn_idx] = new_cell_labels

                        # Merge clusters that were given the same cell label;
                        # restart the scan after every merge because removing
                        # a cluster invalidates the indices
                        merged = True
                        while merged:
                            merged = False
                            for cell_label in catalogueconstructor.clusters[
                                    'cell_label']:
                                cluster_idx = np.where(
                                    catalogueconstructor.clusters['cell_label']
                                    == cell_label)[0]
                                if len(cluster_idx) > 1:
                                    cluster_label1 = catalogueconstructor.cluster_labels[
                                        cluster_idx[0]]
                                    cluster_label2 = catalogueconstructor.cluster_labels[
                                        cluster_idx[1]]

                                    print('auto_merge', cluster_label2, 'with',
                                          cluster_label1)
                                    # Reassign the second cluster's peaks to
                                    # the first, then drop the empty cluster
                                    mask = catalogueconstructor.all_peaks[
                                        'cluster_label'] == cluster_label2
                                    catalogueconstructor.all_peaks[
                                        'cluster_label'][mask] = cluster_label1
                                    catalogueconstructor.remove_one_cluster(
                                        cluster_label2)
                                    merged = True
                                    break

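                        # Rebuild the catalogue so the Peeler runs with the
                        # merged cluster set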
                        catalogueconstructor.make_catalogue_for_peeler()

                        fig = plot_clusters_summary(new_dataio,
                                                    catalogueconstructor,
                                                    ch_grp)
                        fname = 'chan_%d_final_clusters.png' % ch_grp
                        fig.savefig(os.path.join(plot_output_dir, fname))
                        plt.close(fig)
                        channel_result['final'] = os.path.join(
                            'array_%d' % array_idx, 'figures',
                            'catalogue_comparison', fname)

                    channel_results.append(channel_result)

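        # Render an HTML report of the merge decisions with Jinja2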
        env = Environment(loader=FileSystemLoader(cfg['template_dir']))
        template = env.get_template('spike_sorting_merge_template.html')
        template_output = template.render(subject=subject,
                                          recording_date=date,
                                          channel_results=channel_results)

        out_filename = os.path.join(cfg['single_unit_spike_sorting_dir'],
                                    subject, date,
                                    'spike_sorting_merge_report.html')
        with open(out_filename, 'w') as fh:
            fh.write(template_output)

        copyfile(
            os.path.join(cfg['template_dir'], 'style.css'),
            os.path.join(cfg['single_unit_spike_sorting_dir'], subject, date,
                         'style.css'))
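
The cross-day matching above reduces to cosine similarity between template
waveforms, with a threshold deciding merge versus new unit. Below is a
minimal, self-contained sketch of that idea in plain NumPy; the function name
match_new_to_old and its threshold parameter are illustrative assumptions,
not the project's API (the code above uses its project-specific
metrics.cosine_similarity_with_max instead).

import numpy as np

def match_new_to_old(new_wfs, old_wfs, threshold=0.9):
    # new_wfs: (n_new, n_samples); old_wfs: (n_old, n_samples)
    # Normalize each template so dot products become cosine similarities
    new_n = new_wfs / np.linalg.norm(new_wfs, axis=1, keepdims=True)
    old_n = old_wfs / np.linalg.norm(old_wfs, axis=1, keepdims=True)
    sim = new_n @ old_n.T  # (n_new, n_old) similarity matrix
    matches = {}
    for i in range(sim.shape[0]):
        j = int(np.argmax(sim[i]))
        # Best old match above threshold -> merge; None -> new cell label
        matches[i] = j if sim[i, j] >= threshold else None
    return matches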