def test_good_events():
    """Plot the waveforms rejected by ``good_events`` against a median +/- 5*MAD envelope.

    Peaks are detected on segment 0 of the ``datatest`` dataset (threshold -4,
    negative peaks, n_span 2). Waveforms are cut from -30 to +50 samples
    around each peak, and ``good_events`` (upper_thr=5, lower_thr=-5) returns
    a boolean mask of the waveforms to keep. The rejected waveforms are drawn
    in red. The median waveform (thick magenta) and the median +/- 5 * scaled
    MAD limits (magenta) are overlaid for visual comparison.
    """
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)
    detector = PeakDetector(sigs)
    positions = detector.detect_peaks(threshold=-4, peak_sign='-', n_span=2)

    waveforms = extract_peak_waveforms(sigs, positions, positions, -30, 50)
    keep = good_events(waveforms, upper_thr=5., lower_thr=-5.)
    accepted = waveforms[keep]  # not plotted; kept for interactive inspection
    rejected = waveforms[~keep]

    fig, ax = pyplot.subplots()
    rejected.transpose().plot(ax=ax, color='r', lw=.3)

    # Robust spread: MAD scaled by 1.4826 (presumably the usual
    # normal-consistency factor, so the spread is comparable to a std).
    median_wf = waveforms.median(axis=0)
    scaled_mad = np.median(np.abs(waveforms - median_wf), axis=0) * 1.4826
    upper = median_wf + 5 * scaled_mad
    lower = median_wf - 5 * scaled_mad

    median_wf.plot(ax=ax, color='m', lw=2)
    upper.plot(ax=ax, color='m')
    lower.plot(ax=ax, color='m')
def test_clustering():
    """Run the clustering pipeline on segment 0 of ``datatest``, plotting each stage.

    Stages:
    1. peak detection (threshold -4, negative peaks, n_span 2);
    2. waveform extraction over -30..50 samples, with limits found using a
       MAD threshold of 1.1, then adjusted (shortened) waveforms;
    3. PCA projection with 5 components, then clustering into 7 clusters;
    4. catalogue construction and plots;
    5. manual edits via ``merge_cluster(1, 2)`` and ``split_cluster(1, 2)``.
    """
    sigs = DataIO(dirname='datatest').get_signals(seg_num=0)

    # Peak detection
    detector = PeakDetector(sigs)
    detector.detect_peaks(threshold=-4, peak_sign='-', n_span=2)

    # Waveform extraction
    extractor = WaveformExtractor(detector, n_left=-30, n_right=50)
    extractor.find_good_limits(mad_threshold=1.1)
    short_waveforms = extractor.get_ajusted_waveforms()
    print(short_waveforms.shape)

    # Clustering: PCA projection, then k-means
    clustering = Clustering(short_waveforms)
    clustering.project(method='pca', n_components=5)
    clustering.plot_projection(plot_density=True)
    clustering.find_clusters(7)
    clustering.plot_projection(plot_density=True)

    # Make the catalogue and inspect it
    clustering.construct_catalogue()
    clustering.plot_derivatives()
    clustering.plot_catalogue()

    # Manual cluster edits
    clustering.merge_cluster(1, 2)
    clustering.split_cluster(1, 2)
def test_rectify_signals():
    """Plot normalized-then-rectified signals of segment 0 over the 3.14-3.22 index window.

    Rectification uses ``threshold=-4``; the y-axis is fixed to [-20, 10].
    """
    sigs = DataIO(dirname='datatest').get_signals(seg_num=0)
    normed = normalize_signals(sigs)
    rectified = rectify_signals(normed, threshold=-4)

    fig, ax = pyplot.subplots()
    rectified[3.14:3.22].plot(ax=ax)
    ax.set_ylim(-20, 10)
def test_peakdetector():
    """Run ``PeakDetector.detect_peaks`` twice on one detector with different settings.

    The first run uses (threshold -4, n_span 2) and the second uses
    (threshold -5, n_span 5), both with negative peaks. The number of
    detected peaks (``peak_pos.size``) is printed after each run.
    """
    sigs = DataIO(dirname='datatest').get_signals(seg_num=0)
    detector = PeakDetector(sigs, seg_num=0)
    for thresh, span in ((-4, 2), (-5, 5)):
        detector.detect_peaks(threshold=thresh, peak_sign='-', n_span=span)
        print(detector.peak_pos.size)
def test_peeler_plot_spiketrains():
    """Run two peeler passes and plot them, using ``Peeler.plot_spiketrains`` for spikes.

    This test was previously also named ``test_peeler``. A later definition in
    the module reused that name, so this one was shadowed and never collected
    by pytest.

    Pipeline: peak detection (threshold -4, n_span 5), waveform extraction
    (-30..50 samples, MAD threshold 1.1), PCA with 4 components, 8 ordered
    clusters, catalogue, then a Peeler (threshold -5) run twice on the
    normalized signals. The figure has six rows: signals, prediction/residuals
    of pass 0, prediction/residuals of pass 1, and the spike trains.
    """
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)

    # Peak detection
    peakdetector = PeakDetector(sigs)
    peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=5)

    # Waveforms, trimmed to the limits found by find_good_limits
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold=1.1)
    short_wf = waveformextractor.get_ajusted_waveforms()

    # Clustering and catalogue
    clustering = Clustering(short_wf)
    clustering.project(method='pca', n_components=4)
    clustering.find_clusters(8, order_clusters=True)
    catalogue = clustering.construct_catalogue()
    clustering.plot_catalogue(sameax=False)

    # Peeler: two successive passes on the normalized signals
    signals = peakdetector.normed_sigs
    peeler = Peeler(signals, catalogue, limit_left, limit_right,
                    threshold=-5., peak_sign='-', n_span=5)
    prediction0, residuals0 = peeler.peel()
    prediction1, residuals1 = peeler.peel()
    spiketrains = peeler.get_spiketrains()
    print(spiketrains)

    fig, axs = pyplot.subplots(nrows=6, sharex=True)
    traces = (signals, prediction0, residuals0, prediction1, residuals1)
    for ax, trace in zip(axs, traces):
        ax.plot(trace)
        ax.set_ylim(-25, 10)
    peeler.plot_spiketrains(ax=axs[5])
def test_extract_noise_waveforms():
    """Extract 500 noise waveforms and plot their median.

    Peaks are detected first (threshold -4, negative peaks, n_span 2). Their
    positions are passed to ``extract_noise_waveforms`` together with the
    -30..50 window. The resulting shape is printed.
    """
    sigs = DataIO(dirname='datatest').get_signals(seg_num=0)
    detector = PeakDetector(sigs)
    positions = detector.detect_peaks(threshold=-4, peak_sign='-', n_span=2)

    noise_wf = extract_noise_waveforms(sigs, positions, -30, 50, size=500)
    print(noise_wf.shape)

    fig, ax = pyplot.subplots()
    noise_wf.median(axis=0).plot(ax=ax)
def test_filter():
    """Compare unfiltered signals with two ``SignalFilter`` outputs (highpass_freq=300).

    The second filter additionally uses ``box_smooth=3``. The first 0.5 of the
    index range of each version is plotted on its own row of a shared-x figure:
    raw, filtered, filtered and smoothed.
    """
    dataio = DataIO(dirname='datatest_neo')
    sigs = dataio.get_signals(seg_num=0, signal_type='unfiltered')

    # Named ``highpass*`` rather than ``filter`` to avoid shadowing the builtin.
    highpass = SignalFilter(sigs, highpass_freq=300.)
    filtered_sigs = highpass.get_filtered_data()
    highpass_smoothed = SignalFilter(sigs, highpass_freq=300., box_smooth=3)
    filtered_sigs_smoothed = highpass_smoothed.get_filtered_data()

    fig, axs = pyplot.subplots(nrows=3, sharex=True)
    for ax, data in zip(axs, (sigs, filtered_sigs, filtered_sigs_smoothed)):
        data[0:.5].plot(ax=ax)
def test_extract_peak_waveforms():
    """Plot median peak waveforms cut from raw signals and from normalized signals.

    Raw waveforms span -30..50 samples around each peak, and their shape is
    printed. Normalized waveforms are cut from ``PeakDetector.normed_sigs``
    over -15..50. Each median is drawn in its own figure.
    """
    sigs = DataIO(dirname='datatest').get_signals(seg_num=0)
    detector = PeakDetector(sigs)
    positions = detector.detect_peaks(threshold=-4, peak_sign='-', n_span=2)
    index = sigs.index[positions]

    raw_wf = extract_peak_waveforms(sigs, positions, index, -30, 50)
    print(raw_wf.shape)
    fig, ax = pyplot.subplots()
    raw_wf.median(axis=0).plot(ax=ax)

    normed_wf = extract_peak_waveforms(detector.normed_sigs, positions, index, -15, 50)
    fig, ax = pyplot.subplots()
    normed_wf.median(axis=0).plot(ax=ax)
def test_waveform_extractor():
    """Check that adjusted waveforms are narrower than the full -30..50 waveforms.

    The limits found with ``mad_threshold=1.1`` are printed. The assertion
    compares the column counts of ``long_waveforms`` and
    ``get_ajusted_waveforms()``. The chosen limits are then plotted.
    """
    sigs = DataIO(dirname='datatest').get_signals(seg_num=0)
    detector = PeakDetector(sigs)
    detector.detect_peaks(threshold=-4, peak_sign='-', n_span=5)

    extractor = WaveformExtractor(detector, n_left=-30, n_right=50)
    left, right = extractor.find_good_limits(mad_threshold=1.1)
    print(left, right)

    n_long = extractor.long_waveforms.shape[1]
    n_short = extractor.get_ajusted_waveforms().shape[1]
    assert n_long > n_short

    extractor.plot_good_limit()
def test_peeler():
    """Run two peeler passes and plot signals, predictions, residuals and spike trains.

    Pipeline: peak detection (threshold -4, n_span 5), waveform extraction
    (-30..50 samples, MAD threshold 1.1, adjusted with margin 2), PCA with
    5 components, 7 clusters, catalogue, then a Peeler (threshold -4) run twice
    on the normalized signals.

    The last row draws each spike train as '|' markers at a y value equal to
    its key in ``get_spiketrains()``, with one husl color per entry.
    """
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)

    # Peak detection
    peakdetector = PeakDetector(sigs)
    peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=5)

    # Waveforms
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold=1.1)
    short_wf = waveformextractor.get_ajusted_waveforms(margin=2)

    # Clustering
    clustering = Clustering(short_wf)
    clustering.project(method='pca', n_components=5)
    clustering.find_clusters(7)
    catalogue = clustering.construct_catalogue()
    clustering.plot_catalogue()

    # Peeler
    signals = peakdetector.normed_sigs
    peeler = Peeler(signals, catalogue, limit_left, limit_right,
                    threshold=-4, peak_sign='-', n_span=5)
    prediction0, residuals0 = peeler.peel()
    prediction1, residuals1 = peeler.peel()

    fig, axs = pyplot.subplots(nrows=6, sharex=True)
    for ax, trace in zip(axs, (signals, prediction0, residuals0, prediction1, residuals1)):
        ax.plot(trace)

    colors = sns.color_palette('husl', len(catalogue))
    spiketrains = peeler.get_spiketrains()
    for i, (k, pos) in enumerate(spiketrains.items()):
        axs[5].plot(pos, np.ones(pos.size) * k, ls='None', marker='|',
                    markeredgecolor=colors[i], markersize=10, markeredgewidth=2)
    axs[5].set_ylim(0, len(catalogue))
def plot_interpolation():
    """Visualize sub-sample interpolation of one catalogue waveform.

    The catalogue is built as in ``test_peeler``. For the catalogue entry at
    key position 1, the waveform is predicted at jitters ``np.arange(-.5, .5, .1)``
    with ``w0 + j*w1 + j**2/2*w2``. Here ``w0``, ``w1`` and ``w2`` are the
    entry's 'center', 'centerD' and 'centerDD' arrays (presumably the center
    and its first/second derivatives). Each prediction is plotted shifted by
    its jitter, and all predictions are interleaved into one finely sampled
    magenta curve.
    """
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)

    # Peak detection
    peakdetector = PeakDetector(sigs)
    peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=5)

    # Waveforms
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    waveformextractor.find_good_limits(mad_threshold=1.1)
    short_wf = waveformextractor.get_ajusted_waveforms(margin=2)

    # Clustering
    clustering = Clustering(short_wf)
    clustering.project(method='pca', n_components=5)
    clustering.find_clusters(7)
    catalogue = clustering.construct_catalogue()

    k = list(catalogue.keys())[1]
    w0 = catalogue[k]['center']
    w1 = catalogue[k]['centerD']
    w2 = catalogue[k]['centerDD']

    fig, ax = pyplot.subplots()
    t = np.arange(w0.size)
    jitters = np.arange(-.5, .5, .1)
    # One color per jitter, so changing the jitter range cannot run past the palette.
    colors = sns.color_palette('husl', len(jitters))

    # ``predictions`` instead of ``all`` to avoid shadowing the builtin.
    predictions = []
    for color, jitter in zip(colors, jitters):
        pred = w0 + jitter * w1 + jitter ** 2 / 2. * w2
        predictions.append(pred)
        ax.plot(t + jitter, pred, marker='o', label=f'{jitter:.1f}',
                color=color, linestyle='None')
    ax.plot(t, w0, marker='*', markersize=4, label='w0', lw=1, color='k')

    # Interleave: sample 0 of every jitter, then sample 1 of every jitter, ...
    predictions = np.array(predictions)
    interpolated = predictions.transpose().flatten()
    t2 = np.arange(interpolated.size) / predictions.shape[0] + jitters[0]
    ax.plot(t2, interpolated, label='interp', lw=1, color='m')
def test_detect_peak_method_span():
    """Overlay ``detect_peak_method_span`` peaks on rectified and on normalized signals.

    Peaks are detected once on the rectified signals (threshold -4, negative
    peaks, n_span 5). In two figures (rectified first, then normalized), the
    3.14-3.22 index window is plotted with the detected peaks drawn as black
    dots. The y-axis is fixed to [-20, 10].
    """
    sigs = DataIO(dirname='datatest').get_signals(seg_num=0)
    normed_sigs = normalize_signals(sigs)
    rectified_sigs = rectify_signals(normed_sigs, threshold=-4)
    peaks_pos = detect_peak_method_span(rectified_sigs, peak_sign='-', n_span=5)
    peaks_index = sigs.index[peaks_pos]

    for frame in (rectified_sigs, normed_sigs):
        fig, ax = pyplot.subplots()
        frame[3.14:3.22].plot(ax=ax)
        peak_values = frame.loc[peaks_index]
        peak_values[3.14:3.22].plot(marker='o', linestyle='None', ax=ax, color='k')
        ax.set_ylim(-20, 10)
def test_find_good_limits():
    """Plot the per-channel MAD of normalized peak waveforms and the limits chosen from it.

    Waveforms are cut from ``normed_sigs`` over -25..50 samples. The scaled MAD
    (factor 1.4826) is reshaped to one row per channel and plotted, with a
    horizontal line at 1.1. The limits returned by ``find_good_limits`` are
    printed and drawn as vertical lines.
    """
    dataio = DataIO(dirname='datatest')
    sigs = dataio.get_signals(seg_num=0)
    peakdetector = PeakDetector(sigs)
    peak_pos = peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=5)
    peak_index = sigs.index[peak_pos]
    normed_waveforms = extract_peak_waveforms(peakdetector.normed_sigs, peak_pos,
                                              peak_index, -25, 50)

    normed_med = normed_waveforms.median(axis=0)
    normed_mad = np.median(np.abs(normed_waveforms - normed_med), axis=0) * 1.4826
    # Derive the channel count from the signals instead of hard-coding 4.
    # Assumes waveform columns are laid out channel by channel — TODO confirm.
    n_channels = sigs.shape[1]
    normed_mad = normed_mad.reshape(n_channels, -1)

    fig, ax = pyplot.subplots()
    ax.plot(normed_mad.transpose())
    ax.axhline(1.1)
    l1, l2 = find_good_limits(normed_mad)
    print(l1, l2)
    ax.axvline(l1)
    ax.axvline(l2)
def test_normalize_signals():
    """Plot the normalized signals of segment 0 over the 3.14-3.22 index window."""
    raw = DataIO(dirname='datatest').get_signals(seg_num=0)
    window = normalize_signals(raw)[3.14:3.22]
    window.plot()
def test_derivative_signals():
    """Plot ``derivative_signals`` of segment 0 over the 3.14-3.22 index window."""
    raw = DataIO(dirname='datatest').get_signals(seg_num=0)
    window = derivative_signals(raw)[3.14:3.22]
    window.plot()