def test_good_events():
    """Split detected peak waveforms into good/bad events and plot the bad ones.

    Peaks are detected on segment 0 of the 'datatest' dataset, waveforms are
    cut from -30 to 50 samples around each peak, and ``good_events`` decides
    which ones to keep (thresholds +/-5). Rejected waveforms are drawn in red,
    overlaid (magenta) with the median waveform and the median +/- 5*MAD band.
    """
    raw = DataIO(dirname = 'datatest').get_signals(seg_num=0)

    detector = PeakDetector(raw)
    positions = detector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 2)

    waveforms = extract_peak_waveforms(raw, positions, positions, -30, 50)
    is_good = good_events(waveforms, upper_thr=5., lower_thr=-5.)
    good_waveforms = waveforms[is_good]  # available for inspection; not plotted
    bad_waveforms = waveforms[~is_good]

    fig, ax = pyplot.subplots()
    bad_waveforms.transpose().plot(ax=ax, color='r', lw=.3)

    # Robust per-sample centre (median) and spread (MAD scaled by 1.4826,
    # the usual factor to make MAD comparable to a standard deviation).
    center = waveforms.median(axis=0)
    spread = np.median(np.abs(waveforms - center), axis=0) * 1.4826
    upper_band = center + 5 * spread
    lower_band = center - 5 * spread

    center.plot(ax=ax, color='m', lw=2)
    upper_band.plot(ax=ax, color='m')
    lower_band.plot(ax=ax, color='m')
def test_clustering():
    """Exercise the clustering pipeline end to end and plot each stage.

    Steps: peak detection -> waveform extraction with adjusted limits ->
    PCA projection (5 components) -> k-means style clustering into 7 groups ->
    catalogue construction -> a merge and a split of clusters.
    """
    raw = DataIO(dirname = 'datatest').get_signals(seg_num=0)

    # Peak detection.
    detector = PeakDetector(raw)
    peak_pos = detector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 2)

    # Waveform extraction, trimmed to the limits chosen from the MAD profile.
    extractor = WaveformExtractor(detector, n_left=-30, n_right=50)
    limit_left, limit_right = extractor.find_good_limits(mad_threshold = 1.1)
    short_wf = extractor.get_ajusted_waveforms()
    print(short_wf.shape)

    clustering = Clustering(short_wf)

    # PCA projection, shown before clustering.
    features = clustering.project(method = 'pca', n_components = 5)
    clustering.plot_projection(plot_density = True)

    # Clustering into 7 groups, shown again with labels.
    labels = clustering.find_clusters(7)
    clustering.plot_projection(plot_density = True)

    # Make the catalogue and inspect it.
    catalogue = clustering.construct_catalogue()
    clustering.plot_derivatives()
    clustering.plot_catalogue()

    # Manual curation operations.
    clustering.merge_cluster(1,2)
    clustering.split_cluster(1, 2)
def test_rectify_signals():
    """Plot normalized-then-rectified signals over the 3.14-3.22 index window."""
    raw = DataIO(dirname = 'datatest').get_signals(seg_num=0)
    normalized = normalize_signals(raw)
    rectified = rectify_signals(normalized, threshold = -4)

    fig, ax = pyplot.subplots()
    window = rectified[3.14:3.22]
    window.plot(ax = ax)
    ax.set_ylim(-20, 10)
def test_peakdetector():
    """Run peak detection with two parameter sets and print the peak counts.

    The same detector is reused; ``peak_pos`` is read after each call.
    """
    raw = DataIO(dirname = 'datatest').get_signals(seg_num=0)
    detector = PeakDetector(raw, seg_num=0)

    for threshold, n_span in ((-4, 2), (-5, 5)):
        detector.detect_peaks(threshold=threshold, peak_sign = '-', n_span = n_span)
        print(detector.peak_pos.size)
# Example #5
def test_peeler():
    """Build a catalogue, peel the signals twice, and plot every stage.

    Pipeline: peak detection (n_span=5) -> waveform extraction with adjusted
    limits -> PCA (4 components) -> 8 ordered clusters -> catalogue ->
    two ``Peeler.peel`` passes. Six stacked axes show the normalized signals,
    the prediction/residual of each pass, and the spike trains.

    NOTE(review): this module defines ``test_peeler`` a second time further
    down; that later definition rebinds the name, so this version is never
    reachable by name (pytest collects only the last one). Consider renaming.
    """
    data_source = DataIO(dirname = 'datatest')
    #~ data_source = DataIO(dirname = 'datatest_neo')
    raw = data_source.get_signals(seg_num=0)

    # Peak detection.
    detector = PeakDetector(raw)
    peak_pos = detector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 5)

    # Waveforms trimmed to the MAD-based limits.
    extractor = WaveformExtractor(detector, n_left=-30, n_right=50)
    limit_left, limit_right = extractor.find_good_limits(mad_threshold = 1.1)
    short_wf = extractor.get_ajusted_waveforms()

    # Clustering and catalogue.
    clustering = Clustering(short_wf)
    features = clustering.project(method = 'pca', n_components = 4)
    clustering.find_clusters(8, order_clusters = True)
    catalogue = clustering.construct_catalogue()
    #~ clustering.merge_cluster(1, 2)
    # Rebuilt so that any curation step enabled above is reflected.
    catalogue = clustering.construct_catalogue()
    clustering.plot_catalogue(sameax = False)

    # Peeling: two successive passes over the normalized signals.
    signals = detector.normed_sigs
    peeler = Peeler(signals, catalogue,  limit_left, limit_right,
                            threshold=-5., peak_sign = '-', n_span = 5)
    prediction0, residuals0 = peeler.peel()
    prediction1, residuals1 = peeler.peel()

    spiketrains = peeler.get_spiketrains()
    print(spiketrains)

    fig, axs = pyplot.subplots(nrows = 6, sharex = True)
    traces = (signals, prediction0, residuals0, prediction1, residuals1)
    for trace_ax, trace in zip(axs, traces):
        trace_ax.plot(trace)
        trace_ax.set_ylim(-25, 10)

    peeler.plot_spiketrains(ax = axs[5])
def test_extract_noise_waveforms():
    """Extract 500 noise waveforms (-30..50 samples) and plot their median.

    Detected peak positions are passed to ``extract_noise_waveforms``,
    presumably so it can avoid sampling windows around spikes (TODO confirm).
    """
    raw = DataIO(dirname = 'datatest').get_signals(seg_num=0)

    detector = PeakDetector(raw)
    positions = detector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 2)

    noise_wf = extract_noise_waveforms(raw, positions, -30, 50, size = 500)
    print(noise_wf.shape)

    fig, ax = pyplot.subplots()
    median_wf = noise_wf.median(axis=0)
    median_wf.plot(ax = ax)
# Example #7
def test_filter():
    """Compare raw signals against two high-pass (300 Hz) filtered versions.

    The second filter additionally uses ``box_smooth = 3``. The first 0.5
    units of the index are plotted on three stacked, x-shared axes.
    """
    source = DataIO(dirname = 'datatest_neo')
    unfiltered = source.get_signals(seg_num=0, signal_type = 'unfiltered')

    # Named to avoid shadowing the builtin ``filter``.
    highpass = SignalFilter(unfiltered, highpass_freq = 300.)
    highpassed = highpass.get_filtered_data()

    highpass_smoothed = SignalFilter(unfiltered, highpass_freq = 300., box_smooth = 3)
    highpassed_smoothed = highpass_smoothed.get_filtered_data()

    fig, axs = pyplot.subplots(nrows=3, sharex = True)
    for panel, data in zip(axs, (unfiltered, highpassed, highpassed_smoothed)):
        data[0:.5].plot(ax=panel)
def test_extract_peak_waveforms():
    """Extract peak waveforms from raw and normalized signals and plot medians.

    Raw waveforms span -30..50 samples; normalized ones span -15..50.
    Each median waveform is drawn on its own figure.
    """
    raw = DataIO(dirname = 'datatest').get_signals(seg_num=0)

    detector = PeakDetector(raw)
    positions = detector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 2)
    index_labels = raw.index[positions]

    def _plot_median(wf):
        # New figure per waveform set.
        fig, ax = pyplot.subplots()
        wf.median(axis=0).plot(ax =ax)

    raw_wf = extract_peak_waveforms(raw, positions, index_labels, -30, 50)
    print(raw_wf.shape)
    _plot_median(raw_wf)

    normed_wf = extract_peak_waveforms(detector.normed_sigs, positions, index_labels, -15, 50)
    _plot_median(normed_wf)
def test_waveform_extractor():
    """Check that adjusting waveform limits shortens the extracted waveforms.

    Asserts that the long (-30..50) waveforms have more columns than the
    adjusted ones, then plots the chosen limits.
    """
    raw = DataIO(dirname = 'datatest').get_signals(seg_num=0)

    detector = PeakDetector(raw)
    detector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 5)

    extractor = WaveformExtractor(detector, n_left=-30, n_right=50)
    left, right = extractor.find_good_limits(mad_threshold = 1.1)
    print(left, right)

    full_length = extractor.long_waveforms
    trimmed = extractor.get_ajusted_waveforms()
    assert full_length.shape[1] > trimmed.shape[1]

    extractor.plot_good_limit()
def test_peeler():
    """Build a 7-cluster catalogue, peel twice, and plot signals and spikes.

    Pipeline: peak detection (n_span=5) -> waveform extraction with adjusted
    limits (margin=2) -> PCA (5 components) -> 7 clusters -> catalogue ->
    two ``Peeler.peel`` passes. Six x-shared axes show the normalized signals,
    prediction/residual of each pass, and a raster of spike positions per
    cluster, colored with one husl color per catalogue entry.

    NOTE(review): an earlier ``test_peeler`` exists in this module; this
    definition rebinds the name, so only this version is collected.
    """
    sigs = DataIO(dirname = 'datatest').get_signals(seg_num=0)

    # Peak detection.
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 5)

    # Waveforms trimmed to the MAD-based limits.
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold = 1.1)
    short_wf = waveformextractor.get_ajusted_waveforms(margin=2)

    # Clustering and catalogue.
    clustering = Clustering(short_wf)
    features = clustering.project(method = 'pca', n_components = 5)
    clustering.find_clusters(7)
    catalogue = clustering.construct_catalogue()
    clustering.plot_catalogue()

    # Peeling: two successive passes over the normalized signals.
    signals = peakdetector.normed_sigs
    peeler = Peeler(signals, catalogue,  limit_left, limit_right,
                            threshold=-4, peak_sign = '-', n_span = 5)
    prediction0, residuals0 = peeler.peel()
    prediction1, residuals1 = peeler.peel()

    fig, axs = pyplot.subplots(nrows = 6, sharex = True)
    for trace_ax, trace in zip(axs, (signals, prediction0, residuals0, prediction1, residuals1)):
        trace_ax.plot(trace)

    # Spike raster: each cluster key k is drawn as a row of '|' markers at y=k.
    # Assumes get_spiketrains() yields at most len(catalogue) entries, since
    # the palette is sized from the catalogue.
    raster_ax = axs[5]
    colors = sns.color_palette('husl', len(catalogue))
    spiketrains = peeler.get_spiketrains()
    for color_index, (k, pos) in enumerate(spiketrains.items()):
        raster_ax.plot(pos, np.ones(pos.size) * k, ls = 'None', marker = '|',
                       markeredgecolor = colors[color_index], markersize = 10,
                       markeredgewidth = 2)
    raster_ax.set_ylim(0, len(catalogue))
def plot_interpolation():
    """Visualize sub-sample jitter interpolation of one catalogue waveform.

    A catalogue is built (peaks -> adjusted waveforms -> PCA -> 7 clusters),
    then the second catalogue entry's ``center``, ``centerD`` and ``centerDD``
    (presumably the centre waveform and its first/second derivatives — TODO
    confirm) are combined as ``w0 + j*w1 + j**2/2*w2`` for jitters j in
    [-0.5, 0.5) with step 0.1. Each prediction is drawn shifted by its jitter,
    the original centre is drawn in black, and the predictions are
    interleaved sample by sample into one finely sampled curve (magenta).
    """
    sigs = DataIO(dirname = 'datatest').get_signals(seg_num=0)

    # Peak detection.
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 5)

    # Waveforms trimmed to the MAD-based limits.
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold = 1.1)
    short_wf = waveformextractor.get_ajusted_waveforms(margin=2)

    # Clustering and catalogue.
    clustering = Clustering(short_wf)
    features = clustering.project(method = 'pca', n_components = 5)
    clustering.find_clusters(7)
    catalogue = clustering.construct_catalogue()

    k = list(catalogue.keys())[1]
    w0 = catalogue[k]['center']
    w1 = catalogue[k]['centerD']
    w2 = catalogue[k]['centerDD']

    fig, ax = pyplot.subplots()
    t = np.arange(w0.size)
    colors = sns.color_palette('husl', 12)

    # Named ``predictions`` (not ``all``) to avoid shadowing the builtin.
    predictions = []
    jitters = np.arange(-.5, .5, .1)
    for i, jitter in enumerate(jitters):
        pred = w0 + jitter * w1 + jitter ** 2 / 2. * w2
        predictions.append(pred)
        ax.plot(t + jitter, pred, marker = 'o', label = str(jitter), color = colors[i], linestyle = 'None')
    ax.plot(t, w0, marker = '*', markersize = 4, label = 'w0', lw = 1, color = 'k')

    # Shape (n_jitters, n_samples) -> transpose + flatten interleaves jitters
    # within each sample, giving a curve sampled every 1/n_jitters samples.
    predictions = np.array(predictions)
    interpolated = predictions.transpose().flatten()
    t2 = np.arange(interpolated.size) / predictions.shape[0] + jitters[0]
    ax.plot(t2, interpolated, label = 'interp', lw = 1, color = 'm')
def test_detect_peak_method_span():
    """Detect peaks on rectified signals and mark them on two views.

    Peaks come from ``detect_peak_method_span`` applied to the rectified
    (threshold -4) normalized signals. The 3.14-3.22 window is plotted twice,
    once for the rectified and once for the normalized signals, with the
    peaks marked as black dots.
    """
    raw = DataIO(dirname = 'datatest').get_signals(seg_num=0)
    normed = normalize_signals(raw)
    rectified = rectify_signals(normed, threshold = -4)
    positions = detect_peak_method_span(rectified,  peak_sign='-', n_span = 5)
    peak_labels = raw.index[positions]

    for view in (rectified, normed):
        fig, ax = pyplot.subplots()
        view[3.14:3.22].plot(ax = ax)
        at_peaks = view.loc[peak_labels]
        at_peaks[3.14:3.22].plot(marker = 'o', linestyle = 'None', ax = ax, color = 'k')
        ax.set_ylim(-20, 10)
def test_find_good_limits():
    """Plot the per-sample MAD of normalized waveforms and the chosen limits.

    Waveforms span -25..50 samples around each peak. The scaled MAD profile
    is reshaped into 4 rows and plotted with a 1.1 reference line, followed
    by vertical lines at the limits returned by ``find_good_limits``.
    """
    raw = DataIO(dirname = 'datatest').get_signals(seg_num=0)

    detector = PeakDetector(raw)
    positions = detector.detect_peaks(threshold=-4, peak_sign = '-', n_span = 5)
    labels = raw.index[positions]

    normed_wf = extract_peak_waveforms(detector.normed_sigs, positions, labels, -25, 50)

    center = normed_wf.median(axis=0)
    # NOTE(review): the 4 here looks like a hard-coded channel count for the
    # 'datatest' dataset — confirm before reusing with other data.
    mad_profile = (np.median(np.abs(normed_wf - center), axis=0) * 1.4826).reshape(4, -1)

    fig, ax = pyplot.subplots()
    ax.plot(mad_profile.transpose())
    ax.axhline(1.1)

    left, right = find_good_limits(mad_profile)
    print(left, right)
    for x in (left, right):
        ax.axvline(x)
def test_normalize_signals():
    """Plot the normalized signals over the 3.14-3.22 index window."""
    raw = DataIO(dirname = 'datatest').get_signals(seg_num=0)
    window = normalize_signals(raw)[3.14:3.22]
    window.plot()
def test_derivative_signals():
    """Plot the derivative signals over the 3.14-3.22 index window."""
    raw = DataIO(dirname = 'datatest').get_signals(seg_num=0)
    window = derivative_signals(raw)[3.14:3.22]
    window.plot()