def test1():
    """Sort a synthetic block and inspect its spike trains ("save to neo").

    The unit x segment grid of spike trains is printed before and after
    ``populate_recordingchannelgroup``.  Each row shows the unit index, the
    segment index, whether the SpikeTrain reached through the segment is the
    very same object as the one reached through the unit, and its size.
    """
    def dump_spiketrains(group):
        # One line per (unit, segment) pair.
        for unit_idx, unit in enumerate(group.units):
            for seg_idx, segment in enumerate(group.block.segments):
                train = segment.spiketrains[unit_idx]
                shared = train is unit.spiketrains[seg_idx]
                print('u %s s %s %s %s' % (unit_idx, seg_idx, shared, train.size))

    block = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                       noise_ratio=0.2, nb_segment=2)
    group = block.recordingchannelgroups[0]

    sorter = SpikeSorter(group)
    sorter.ButterworthFilter(f_low=200.)
    sorter.MedianThresholdDetection(sign='-', median_thresh=6.)
    sorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms, right_sweep=2 * pq.ms,
                               sign='-')
    sorter.PcaFeature(n_components=4)
    sorter.CombineFeature(use_peak=True, use_peak_to_valley=True,
                          n_pca=3, n_ica=3, n_haar=3, sign='-')
    sorter.SklearnKMeans(n_cluster=3)

    dump_spiketrains(group)
    group = sorter.populate_recordingchannelgroup()
    print('')
    dump_spiketrains(group)
def test1():
    """Sliding-median filtering plus relative-threshold (peak mode) detection,
    displayed in a ``FilteredBandSignal`` widget."""
    block = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                       noise_ratio=0.2, nb_segment=2)
    sorter = SpikeSorter(block.recordingchannelgroups[0])

    # Alternative pre-processing kept for experimentation:
    #   sorter.ButterworthFilter(f_low=200.)
    #   sorter.DerivativeFilter()
    sorter.SlidingMedianFilter(window_size=50. * pq.ms,
                               sliding_step=25. * pq.ms,
                               interpolation='spline')
    sorter.RelativeThresholdDetection(sign='-', relative_thresh=4.,
                                      noise_estimation='MAD',
                                      threshold_mode='peak',
                                      peak_span=0.5 * pq.ms)
    print(sorter)
    sorter.check_display_attributes()

    from OpenElectrophy.gui.spikesorting import FilteredBandSignal
    app = QApplication([])
    viewer = FilteredBandSignal(spikesorter=sorter)
    viewer.refresh()
    viewer.show()
    app.exec_()
def test1():
    """Butterworth filtering plus relative-threshold detection on a noisy
    dataset.  Prints the detection thresholds and the sorter, populates the
    RecordingChannelGroup without waveforms, and shows the filtered signal.

    Variants tried previously (not run): threshold_mode='crossing' with
    consistent_across_channels=False / consistent_across_segments=True, and
    noise_estimation='STD' in both 'crossing' and 'peak' (peak_span=0.3 ms)
    modes.
    """
    dataset = dict(nb_unit=6, duration=10. * pq.s, noise_ratio=0.7,
                   nb_segment=2)
    block = generate_block_for_sorting(**dataset)
    sorter = SpikeSorter(block.recordingchannelgroups[0])

    sorter.ButterworthFilter(f_low=200.)
    sorter.RelativeThresholdDetection(sign='-', relative_thresh=3.5,
                                      noise_estimation='MAD',
                                      threshold_mode='peak',
                                      peak_span=0.53 * pq.ms)
    print(sorter.detection_thresholds)
    print(sorter)

    sorter.populate_recordingchannelgroup(with_waveforms=False)
    sorter.check_display_attributes()

    from OpenElectrophy.gui.spikesorting import FilteredBandSignal
    app = QApplication([])
    signal_view = FilteredBandSignal(spikesorter=sorter)
    signal_view.refresh()
    signal_view.show()
    app.exec_()
def setUp(self):
    """Wrap the first RecordingChannelGroup of a small synthetic block
    (3 units, 1 s, 2 segments) in a SpikeSorter at the full-band stage."""
    dataset = dict(nb_unit=3, duration=1. * pq.s, noise_ratio=0.2,
                   nb_segment=2)
    group = generate_block_for_sorting(**dataset).recordingchannelgroups[0]
    self.sps = SpikeSorter(group, initial_state='full_band_signal')
def test1():
    """Full chain (filter, detect, align on peak, PCA, Gaussian-mixture EM)
    on a noisy dataset, shown in AverageWaveforms and FilteredBandSignal."""
    block = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                       noise_ratio=0.7, nb_segment=2)
    sorter = SpikeSorter(block.recordingchannelgroups[0])

    sorter.ButterworthFilter(f_low=200.)
    sorter.RelativeThresholdDetection(sign='-', relative_thresh=3.5,
                                      noise_estimation='MAD',
                                      threshold_mode='peak',
                                      peak_span=0.4 * pq.ms)
    sorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms, right_sweep=2 * pq.ms,
                               sign='-', peak_method='biggest_amplitude')
    sorter.PcaFeature(n_components=4)
    sorter.SklearnGaussianMixtureEm(n_cluster=5)
    sorter.check_display_attributes()

    from OpenElectrophy.gui.spikesorting import FilteredBandSignal, AverageWaveforms
    app = QApplication([])
    # Keep a reference to every widget so none is garbage-collected before
    # the event loop runs.
    widgets = []
    for widget_cls in (AverageWaveforms, FilteredBandSignal):
        widget = widget_cls(spikesorter=sorter)
        widget.refresh()
        widget.show()
        widgets.append(widget)
    app.exec_()
def test2():
    """Sort a synthetic block, save it to SQLite, reload it and print the
    spike-train layout ("save to db")."""
    url = 'sqlite:///test_spikesorter.sqlite'
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()

    neo_block = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                           noise_ratio=0.2, nb_segment=2)
    group = neo_block.recordingchannelgroups[0]
    oe_block = OEBase.from_neo(neo_block, dbinfo.mapped_classes, cascade=True)
    oe_block.save()
    block_id = oe_block.id

    # Layout of the freshly generated neo objects: unit, segment, identity
    # of the two access paths, spike count.
    for u, unit in enumerate(group.units):
        for s, seg in enumerate(group.block.segments):
            train = seg.spiketrains[u]
            print('u %s s %s %s %s' % (u, s, train is unit.spiketrains[s],
                                       train.size))

    sorter = SpikeSorter(group)
    sorter.ButterworthFilter(f_low=200.)
    sorter.MedianThresholdDetection(sign='-', median_thresh=6.)
    sorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms, right_sweep=2 * pq.ms,
                               sign='-')
    sorter.PcaFeature(n_components=4)
    sorter.CombineFeature(use_peak=True, use_peak_to_valley=True,
                          n_pca=3, n_ica=3, n_haar=3, sign='-')
    sorter.SklearnKMeans(n_cluster=3)
    sorter.save_in_database(session, dbinfo)

    # Re-open the database and reload the block that was just saved.
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()
    group = Block.load(block_id).recordingchannelgroups[0]

    for s, seg in enumerate(group.block.segments):
        print('s %s %s' % (s, len(seg.spiketrains)))
    for u, unit in enumerate(group.units):
        print('u %s %s' % (u, len(unit.spiketrains)))
    for u, unit in enumerate(group.units):
        for s, seg in enumerate(group.block.segments):
            train = seg.spiketrains[u]
            print('u %s s %s %s %s' % (u, s, train is unit.spiketrains[s],
                                       train.to_neo().size))
def test4():
    """Exercise ``SpikeSorter.delete_one_cluster`` on a fresh sorter.

    Prints the sorter's per-spike state (shape of ``spike_index_array[0]``,
    shape of ``spike_waveforms``, ``seg_spike_slices``, ``waveform_features``
    and ``spike_clusters``) before and after deleting cluster 0, separated by
    a blank line.

    Note: the original header comment said "add a spike" / "save to neo", but
    the code deletes a cluster and saves nothing.
    """
    def dump_state(sorter):
        # The five attributes compared before/after the deletion.
        print(sorter.spike_index_array[0].shape)
        print(sorter.spike_waveforms.shape)
        print(sorter.seg_spike_slices)
        print(sorter.waveform_features)
        print(sorter.spike_clusters)

    bl = generate_block_for_sorting(nb_unit=6, duration=1. * pq.s,
                                    noise_ratio=0.2, nb_segment=2)
    rcg = bl.recordingchannelgroups[0]
    sps = SpikeSorter(rcg)

    dump_state(sps)
    sps.delete_one_cluster(0)
    print('')
    dump_state(sps)
class BasicTest(unittest.TestCase):
    """Smoke tests for SpikeSorter on a small synthetic dataset."""

    def setUp(self):
        dataset = dict(nb_unit=3, duration=1. * pq.s, noise_ratio=0.2,
                       nb_segment=2)
        group = generate_block_for_sorting(**dataset).recordingchannelgroups[0]
        self.sps = SpikeSorter(group, initial_state='full_band_signal')

    def tearDown(self):
        pass

    def test_getattr_aliases(self):
        # ``segs`` must resolve to the same object as ``segments``; an unknown
        # name must raise AttributeError.
        sorter = self.sps
        self.assertIs(sorter.segs, sorter.segments)
        self.assertRaises(AttributeError, getattr, sorter, 'i_love_my_mother')

    def test_getattr_runstep(self):
        # Calling a step by its class name must record an instance of that
        # class as the last history entry.
        self.sps.ButterworthFilter(f_low=200.)
        last_step = self.sps.history[-1]
        self.assertIsInstance(last_step['methodInstance'], ButterworthFilter)

    def test_one_standart_pipeline(self):
        # filter -> detect -> align -> features -> clustering; every stage
        # must populate its output attributes.
        sorter = self.sps

        sorter.ButterworthFilter(f_low=200.)
        self.assertIsNotNone(sorter.filtered_sigs)

        sorter.MedianThresholdDetection(sign='-', median_thresh=6)
        self.assertIsNotNone(sorter.spike_index_array)

        sorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms,
                                        right_sweep=2 * pq.ms)
        for name in ('seg_spike_slices', 'spike_waveforms',
                     'left_sweep', 'right_sweep'):
            self.assertIsNotNone(getattr(sorter, name))

        sorter.PcaFeature(n_components=3)
        self.assertIsNotNone(sorter.waveform_features)

        sorter.SklearnGaussianMixtureEm(n_cluster=12, n_iter=500)
        for name in ('spike_clusters', 'cluster_names'):
            self.assertIsNotNone(getattr(sorter, name))

    def test_apply_history_to_other(self):
        # NOTE(review): no step has been run on self.sps at this point, so
        # this probably replays an empty history and only checks that the
        # call does not raise -- consider running a step first.
        target = SpikeSorter(self.sps.rcg, initial_state='full_band_signal')
        self.sps.apply_history_to_other(target)
def test1():
    """Threshold-crossing detection followed by alignment on the detection
    point, displayed in AverageWaveforms and AllWaveforms.

    Alternatives tried previously (not run): threshold_mode='peak';
    AlignWaveformOnPeak with peak_method 'biggest_amplitude' or 'closer';
    AlignWaveformOnCentralWaveform; a quick matplotlib plot of the
    reshaped ``spike_waveforms``.
    """
    block = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                       noise_ratio=0.2, nb_segment=2)
    sorter = SpikeSorter(block.recordingchannelgroups[0])

    sorter.ButterworthFilter(f_low=200.)
    sorter.RelativeThresholdDetection(sign='-', relative_thresh=4.,
                                      noise_estimation='MAD',
                                      threshold_mode='crossing')
    sorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms,
                                    right_sweep=2 * pq.ms, sign='-')
    print(sorter)
    sorter.check_display_attributes()

    from OpenElectrophy.gui.spikesorting import AverageWaveforms, AllWaveforms
    app = QApplication([])
    # Hold references so the widgets outlive the loop.
    widgets = []
    for widget_cls in (AverageWaveforms, AllWaveforms):
        widget = widget_cls(spikesorter=sorter)
        widget.refresh()
        widget.show()
        widgets.append(widget)
    app.exec_()
class BasicTest(unittest.TestCase):
    """Smoke tests for SpikeSorter on a small synthetic dataset."""

    def setUp(self):
        block = generate_block_for_sorting(nb_unit=3, duration=1. * pq.s,
                                           noise_ratio=0.2, nb_segment=2)
        self.sps = SpikeSorter(block.recordingchannelgroups[0],
                               initial_state='full_band_signal')

    def tearDown(self):
        pass

    def test_getattr_aliases(self):
        # ``segs`` aliases ``segments``; unknown attributes raise.
        sorter = self.sps
        self.assertIs(sorter.segs, sorter.segments)
        self.assertRaises(AttributeError, getattr, sorter, 'i_love_my_mother')

    def test_getattr_runstep(self):
        # A step invoked by class name is recorded in the history.
        self.sps.ButterworthFilter(f_low=200.)
        recorded = self.sps.history[-1]['methodInstance']
        self.assertIsInstance(recorded, ButterworthFilter)

    def test_one_standart_pipeline(self):
        # Each stage of the standard chain must fill in its outputs.
        sorter = self.sps
        check = self.assertIsNotNone

        sorter.ButterworthFilter(f_low=200.)
        check(sorter.filtered_sigs)

        sorter.MedianThresholdDetection(sign='-', median_thresh=6)
        check(sorter.spike_index_array)

        sorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms,
                                        right_sweep=2 * pq.ms)
        check(sorter.seg_spike_slices)
        check(sorter.spike_waveforms)
        check(sorter.left_sweep)
        check(sorter.right_sweep)

        sorter.PcaFeature(n_components=3)
        check(sorter.waveform_features)

        sorter.SklearnGaussianMixtureEm(n_cluster=12, n_iter=500)
        check(sorter.spike_clusters)
        check(sorter.cluster_names)

    def test_apply_history_to_other(self):
        # NOTE(review): self.sps has not run any step here, so the history
        # being applied is probably empty -- consider running a step first.
        other = SpikeSorter(self.sps.rcg, initial_state='full_band_signal')
        self.sps.apply_history_to_other(other)
def test1():
    """Sliding-median filtering and relative-threshold (peak) detection,
    shown in a FilteredBandSignal widget."""
    dataset = dict(nb_unit=6, duration=10. * pq.s, noise_ratio=0.2,
                   nb_segment=2)
    group = generate_block_for_sorting(**dataset).recordingchannelgroups[0]
    sorter = SpikeSorter(group)

    # Other filters tried here: ButterworthFilter(f_low=200.), DerivativeFilter().
    sorter.SlidingMedianFilter(window_size=50. * pq.ms,
                               sliding_step=25. * pq.ms,
                               interpolation='spline')
    sorter.RelativeThresholdDetection(sign='-', relative_thresh=4.,
                                      noise_estimation='MAD',
                                      threshold_mode='peak',
                                      peak_span=0.5 * pq.ms)
    print(sorter)
    sorter.check_display_attributes()

    from OpenElectrophy.gui.spikesorting import FilteredBandSignal
    app = QApplication([])
    band_view = FilteredBandSignal(spikesorter=sorter)
    band_view.refresh()
    band_view.show()
    app.exec_()
def test1():
    """Run a full sorting chain and compare the spike-train grid before and
    after ``populate_recordingchannelgroup`` ("save to neo").

    Each printed row: unit index, segment index, whether the segment's and
    the unit's SpikeTrain are the same object, spike count.
    """
    block = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                       noise_ratio=0.2, nb_segment=2)
    group = block.recordingchannelgroups[0]

    sorter = SpikeSorter(group)
    sorter.ButterworthFilter(f_low=200.)
    sorter.MedianThresholdDetection(sign='-', median_thresh=6.)
    sorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms, right_sweep=2 * pq.ms,
                               sign='-')
    sorter.PcaFeature(n_components=4)
    sorter.CombineFeature(use_peak=True, use_peak_to_valley=True,
                          n_pca=3, n_ica=3, n_haar=3, sign='-')
    sorter.SklearnKMeans(n_cluster=3)

    def report(rcg):
        for ui, unit in enumerate(rcg.units):
            for si, seg in enumerate(rcg.block.segments):
                sptr = seg.spiketrains[ui]
                same = sptr is unit.spiketrains[si]
                print('u %s s %s %s %s' % (ui, si, same, sptr.size))

    report(group)
    group = sorter.populate_recordingchannelgroup()
    print('')
    report(group)
def test1():
    """Align waveforms on the central waveform after a first clustering pass,
    then plot the iterative centers recorded by that alignment step.

    Other settings tried previously (not run): threshold_mode='crossing';
    AlignWaveformOnPeak ('biggest_amplitude' / 'closer'); shift estimation
    'taylor order2' or 'optimize'; shift_method 'spline'.
    """
    block = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                       noise_ratio=0.2, nb_segment=2)
    sorter = SpikeSorter(block.recordingchannelgroups[0])

    sorter.ButterworthFilter(f_low=200.)
    sorter.RelativeThresholdDetection(sign='-', relative_thresh=4.,
                                      noise_estimation='MAD',
                                      threshold_mode='peak')
    sorter.AlignWaveformOnDetection(left_sweep=1.5 * pq.ms,
                                    right_sweep=2.5 * pq.ms, sign='-')
    sorter.PcaFeature(n_components=4)
    sorter.SklearnKMeans(n_cluster=3)

    central_alignment = dict(left_sweep=1 * pq.ms,
                             right_sweep=2 * pq.ms,
                             shift_estimation_method='taylor order1',
                             shift_method='sinc',
                             max_iter=3)
    sorter.AlignWaveformOnCentralWaveform(**central_alignment)

    # The step object just run is stored as the last history entry.
    aligner = sorter.history[-1]['methodInstance']
    fig = pyplot.figure()
    aligner.plot_iterative_centers(fig, sorter)
    pyplot.show()
def test1():
    """Crossing-mode detection, alignment on detection, and display in
    AverageWaveforms and AllWaveforms widgets.

    Alternatives previously tried (not run): peak-mode detection, peak /
    central-waveform alignment, and a matplotlib dump of the waveforms.
    """
    dataset = dict(nb_unit=6, duration=10. * pq.s, noise_ratio=0.2,
                   nb_segment=2)
    group = generate_block_for_sorting(**dataset).recordingchannelgroups[0]
    sorter = SpikeSorter(group)

    sorter.ButterworthFilter(f_low=200.)
    sorter.RelativeThresholdDetection(sign='-', relative_thresh=4.,
                                      noise_estimation='MAD',
                                      threshold_mode='crossing')
    sorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms,
                                    right_sweep=2 * pq.ms, sign='-')
    print(sorter)
    sorter.check_display_attributes()

    from OpenElectrophy.gui.spikesorting import AverageWaveforms, AllWaveforms
    app = QApplication([])
    average_view = AverageWaveforms(spikesorter=sorter)
    average_view.refresh()
    average_view.show()
    all_view = AllWaveforms(spikesorter=sorter)
    all_view.refresh()
    all_view.show()
    app.exec_()
def test1():
    """Relative-threshold detection (MAD, peak mode) on a noisy dataset;
    print thresholds and sorter, populate the group without waveforms, and
    show the filtered signal.

    Variants previously tried (not run): 'crossing' mode with per-segment
    consistency options, and STD-based noise estimation.
    """
    block = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                       noise_ratio=0.7, nb_segment=2)
    sorter = SpikeSorter(block.recordingchannelgroups[0])
    sorter.ButterworthFilter(f_low=200.)

    detection = dict(sign='-', relative_thresh=3.5, noise_estimation='MAD',
                     threshold_mode='peak', peak_span=0.53 * pq.ms)
    sorter.RelativeThresholdDetection(**detection)
    print(sorter.detection_thresholds)
    print(sorter)

    sorter.populate_recordingchannelgroup(with_waveforms=False)
    sorter.check_display_attributes()

    from OpenElectrophy.gui.spikesorting import FilteredBandSignal
    app = QApplication([])
    view = FilteredBandSignal(spikesorter=sorter)
    view.refresh()
    view.show()
    app.exec_()
def setUp(self):
    """Create ``self.sps``: a SpikeSorter over a small synthetic group
    (3 units, 1 s, 2 segments) starting at the full-band-signal stage."""
    block = generate_block_for_sorting(nb_unit=3, duration=1. * pq.s,
                                       noise_ratio=0.2, nb_segment=2)
    first_group = block.recordingchannelgroups[0]
    self.sps = SpikeSorter(first_group, initial_state='full_band_signal')
import sys
sys.path.append('..')

if __name__ == '__main__':
    import quantities as pq
    from OpenElectrophy.spikesorting import (generate_block_for_sorting,
                                             SpikeSorter)

    # Read or create the dataset.
    dataset_params = dict(nb_unit=6, duration=5. * pq.s, noise_ratio=0.2)
    bl = generate_block_for_sorting(**dataset_params)
    rcg = bl.recordingchannelgroups[0]
    spikesorter = SpikeSorter(rcg)

    # Show each unit and its per-segment spike counts before sorting.
    for u, unit in enumerate(rcg.units):
        print('%s unit name %s' % (u, unit.name))
        for s, seg in enumerate(rcg.block.segments):
            sptr = seg.spiketrains[u]
            print(' in Segment %s has SpikeTrain with  %s' % (s, sptr.size))

    # Apply a chain.  Calling a step by name, e.g.
    # spikesorter.ButterworthFilter(f_low=200.), is equivalent to
    # spikesorter.run_step(ButterworthFilter, f_low=200.).
    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.MedianThresholdDetection(sign='-', median_thresh=6)
    spikesorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms,
                                         right_sweep=2 * pq.ms)
    spikesorter.PcaFeature(n_components=6)
if __name__ == '__main__':
    import quantities as pq
    from OpenElectrophy.spikesorting import (generate_block_for_sorting,
                                             SpikeSorter)
    from OpenElectrophy.gui.spikesorting import (AverageWaveforms,
                                                 FeaturesNDViewer,
                                                 FilteredBandSignal)

    # Read or create the dataset.
    dataset_params = dict(nb_unit=6, duration=5. * pq.s, noise_ratio=0.2)
    bl = generate_block_for_sorting(**dataset_params)
    rcg = bl.recordingchannelgroups[0]
    spikesorter = SpikeSorter(rcg)

    # Apply a chain.  spikesorter.ButterworthFilter(f_low=200.) is
    # equivalent to spikesorter.run_step(ButterworthFilter, f_low=200.).
    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.MedianThresholdDetection(sign='-', median_thresh=6)
    spikesorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms,
                                         right_sweep=2 * pq.ms)
    spikesorter.PcaFeature(n_components=6)
    spikesorter.SklearnGaussianMixtureEm(n_cluster=6, n_iter=200)

    # The Qt application must exist before any widget is built.
    from PyQt4.QtGui import QApplication
    app = QApplication([])
    spikesorter.check_display_attributes()
def test_apply_history_to_other(self):
    # Build a second sorter on the same RecordingChannelGroup and apply this
    # sorter's history to it; the test passes if the call does not raise.
    # NOTE(review): if no step has been run on self.sps (e.g. straight after
    # setUp), the applied history is probably empty -- consider running a
    # step first.
    group = self.sps.rcg
    target = SpikeSorter(group, initial_state='full_band_signal')
    self.sps.apply_history_to_other(target)
def test1():
    """Store spike-detection thresholds on SpikeTrain rows of a SQLite DB.

    1. Open the DB and add a ``detection_thresholds`` column (``np.ndarray``)
       to the SpikeTrain table via ``create_column_if_not_exists``; list the
       SpikeTrain attributes before and after.
    2. Re-open the DB (presumably so the mapped classes see the new column),
       save a synthetic block and run a sorting chain on it.
    3. Write ``spikesorter.detection_thresholds`` onto every spike train of
       ``rcg.units`` and commit.
    4. Re-open the DB, reload the RecordingChannelGroup and print the
       thresholds read back.
    """
    def print_spiketrain_attributes(db):
        # One "name type" line per usable attribute of the mapped SpikeTrain class.
        spiketrain_class = db.classes_by_name['SpikeTrain']
        for attrname, attrtype in spiketrain_class.usable_attributes.items():
            print('%s %s' % (attrname, attrtype))

    url = 'sqlite:///test_spikesorter.sqlite'

    # Open DB and create the new column.
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()
    print_spiketrain_attributes(dbinfo)
    # Create a column on table SpikeTrain (maybe this could be stored at Unit level).
    create_column_if_not_exists(dbinfo.metadata.tables['SpikeTrain'],
                                'detection_thresholds', np.ndarray)
    # Check whether the creation is OK.
    print_spiketrain_attributes(dbinfo)

    # Re-open DB for sorting.
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()
    print(dbinfo.classes_by_name['SpikeTrain'])
    print_spiketrain_attributes(dbinfo)

    neobl = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                       noise_ratio=0.2, nb_segment=2)
    neorcg = neobl.recordingchannelgroups[0]
    bl = OEBase.from_neo(neobl, dbinfo.mapped_classes, cascade=True)
    bl.save()
    rcg = OEBase.from_neo(neorcg, dbinfo.mapped_classes, cascade=True)
    rcg_id = rcg.id

    # Spike sorting chain.
    spikesorter = SpikeSorter(neorcg)
    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.RelativeThresholdDetection(sign='-', relative_thresh=6.)
    spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms,
                                    right_sweep=2 * pq.ms, sign='-')
    spikesorter.PcaFeature(n_components=4)
    spikesorter.CombineFeature(use_peak=True, use_peak_to_valley=True,
                               n_pca=3, n_ica=3, n_haar=3, sign='-')
    spikesorter.SklearnKMeans(n_cluster=3)
    spikesorter.save_in_database(session, dbinfo)

    # Note that the threshold is per channel/segment.
    # NOTE(review): the whole thresholds array is stored on every spike
    # train, and rcg was converted before sorting -- confirm that rcg.units
    # reflects the units written by save_in_database.
    for unit in rcg.units:
        for sptr in unit.spiketrains:
            print('writte threshold on spiketrain.id %s' % (sptr.id,))
            print(spikesorter.detection_thresholds)
            sptr.detection_thresholds = spikesorter.detection_thresholds
    session.commit()

    # Re-open DB for reading.
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()
    rcg = RecordingChannelGroup.load(rcg_id)
    for unit in rcg.units:
        for sptr in unit.spiketrains:
            print('read threshold on spiketrain.id %s' % (sptr.id,))
            print(sptr.detection_thresholds)
import sys
sys.path.append("..")

if __name__ == "__main__":
    import quantities as pq
    from OpenElectrophy.spikesorting import (generate_block_for_sorting,
                                             SpikeSorter)
    from OpenElectrophy.gui.spikesorting import (AverageWaveforms,
                                                 FeaturesNDViewer,
                                                 FilteredBandSignal)

    # Read or create the dataset.
    dataset_params = dict(nb_unit=6, duration=5.0 * pq.s, noise_ratio=0.2)
    bl = generate_block_for_sorting(**dataset_params)
    rcg = bl.recordingchannelgroups[0]
    spikesorter = SpikeSorter(rcg)

    # Apply a chain.  Each named call is shorthand for run_step, e.g.
    # spikesorter.run_step(ButterworthFilter, f_low=200.).
    spikesorter.ButterworthFilter(f_low=200.0)
    spikesorter.MedianThresholdDetection(sign="-", median_thresh=6)
    spikesorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms,
                                         right_sweep=2 * pq.ms)
    spikesorter.PcaFeature(n_components=6)
    spikesorter.SklearnGaussianMixtureEm(n_cluster=6, n_iter=200)

    # Qt application first, then the display-attribute check.
    from PyQt4.QtGui import QApplication
    app = QApplication([])
    spikesorter.check_display_attributes()
def test2():
    """Save a sorted synthetic block to SQLite, reload it, and print the
    segment/unit spike-train layout ("save to db")."""
    url = 'sqlite:///test_spikesorter.sqlite'
    db_kwargs = dict(url=url, use_global_session=True, myglobals=globals())

    dbinfo = open_db(**db_kwargs)
    session = dbinfo.Session()

    neo_block = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                           noise_ratio=0.2, nb_segment=2)
    rcg = neo_block.recordingchannelgroups[0]
    stored = OEBase.from_neo(neo_block, dbinfo.mapped_classes, cascade=True)
    stored.save()
    stored_id = stored.id

    for ui, unit in enumerate(rcg.units):
        for si, seg in enumerate(rcg.block.segments):
            sptr = seg.spiketrains[ui]
            same = sptr is unit.spiketrains[si]
            print('u %s s %s %s %s' % (ui, si, same, sptr.size))

    sorter = SpikeSorter(rcg)
    sorter.ButterworthFilter(f_low=200.)
    sorter.MedianThresholdDetection(sign='-', median_thresh=6.)
    sorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms, right_sweep=2 * pq.ms,
                               sign='-')
    sorter.PcaFeature(n_components=4)
    sorter.CombineFeature(use_peak=True, use_peak_to_valley=True,
                          n_pca=3, n_ica=3, n_haar=3, sign='-')
    sorter.SklearnKMeans(n_cluster=3)
    sorter.save_in_database(session, dbinfo)

    # Reload from a fresh connection.
    dbinfo = open_db(**db_kwargs)
    session = dbinfo.Session()
    reloaded = Block.load(stored_id)
    rcg = reloaded.recordingchannelgroups[0]

    for si, seg in enumerate(rcg.block.segments):
        print('s %s %s' % (si, len(seg.spiketrains)))
    for ui, unit in enumerate(rcg.units):
        print('u %s %s' % (ui, len(unit.spiketrains)))
    for ui, unit in enumerate(rcg.units):
        for si, seg in enumerate(rcg.block.segments):
            sptr = seg.spiketrains[ui]
            same = sptr is unit.spiketrains[si]
            print('u %s s %s %s %s' % (ui, si, same, sptr.to_neo().size))
""" # DO NOT FORGET this IPython command %gui qt from OpenElectrophy.spikesorting import SpikeSorter, generate_block_for_sorting import quantities as pq # generate dataset bl = generate_block_for_sorting(nb_unit = 6, duration = 10.*pq.s, noise_ratio = 0.2, nb_segment = 2, ) rcg = bl.recordingchannelgroups[0] spikesorter = SpikeSorter(rcg, initial_state='full_band_signal') # Apply a chain spikesorter.ButterworthFilter( f_low = 200.) spikesorter.MedianThresholdDetection(sign= '-', median_thresh = 6.,) spikesorter.AlignWaveformOnPeak(left_sweep = 1*pq.ms , right_sweep = 2*pq.ms, sign = '-') # display widget interactively from OpenElectrophy.gui.spikesorting import AverageWaveforms w1 = AverageWaveforms(spikesorter = spikesorter) w1.show() # test other step methods
def test1():
    """Store spike-detection thresholds on SpikeTrain rows of a SQLite DB.

    1. Open the DB and add a ``detection_thresholds`` column (``np.ndarray``)
       to the SpikeTrain table via ``create_column_if_not_exists``; list the
       SpikeTrain attributes before and after.
    2. Re-open the DB (presumably so the mapped classes see the new column),
       save a synthetic block and run a sorting chain on it.
    3. Write ``spikesorter.detection_thresholds`` onto every spike train of
       ``rcg.units`` and commit.
    4. Re-open the DB, reload the RecordingChannelGroup and print the
       thresholds read back.
    """
    def print_spiketrain_attributes(db):
        # One "name type" line per usable attribute of the mapped SpikeTrain class.
        spiketrain_class = db.classes_by_name['SpikeTrain']
        for attrname, attrtype in spiketrain_class.usable_attributes.items():
            print('%s %s' % (attrname, attrtype))

    url = 'sqlite:///test_spikesorter.sqlite'

    # Open DB and create the new column.
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()
    print_spiketrain_attributes(dbinfo)
    # Create a column on table SpikeTrain (maybe this could be stored at Unit level).
    create_column_if_not_exists(dbinfo.metadata.tables['SpikeTrain'],
                                'detection_thresholds', np.ndarray)
    # Check whether the creation is OK.
    print_spiketrain_attributes(dbinfo)

    # Re-open DB for sorting.
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()
    print(dbinfo.classes_by_name['SpikeTrain'])
    print_spiketrain_attributes(dbinfo)

    neobl = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                       noise_ratio=0.2, nb_segment=2)
    neorcg = neobl.recordingchannelgroups[0]
    bl = OEBase.from_neo(neobl, dbinfo.mapped_classes, cascade=True)
    bl.save()
    rcg = OEBase.from_neo(neorcg, dbinfo.mapped_classes, cascade=True)
    rcg_id = rcg.id

    # Spike sorting chain.
    spikesorter = SpikeSorter(neorcg)
    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.RelativeThresholdDetection(sign='-', relative_thresh=6.)
    spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms,
                                    right_sweep=2 * pq.ms, sign='-')
    spikesorter.PcaFeature(n_components=4)
    spikesorter.CombineFeature(use_peak=True, use_peak_to_valley=True,
                               n_pca=3, n_ica=3, n_haar=3, sign='-')
    spikesorter.SklearnKMeans(n_cluster=3)
    spikesorter.save_in_database(session, dbinfo)

    # Note that the threshold is per channel/segment.
    # NOTE(review): the whole thresholds array is stored on every spike
    # train, and rcg was converted before sorting -- confirm that rcg.units
    # reflects the units written by save_in_database.
    for unit in rcg.units:
        for sptr in unit.spiketrains:
            print('writte threshold on spiketrain.id %s' % (sptr.id,))
            print(spikesorter.detection_thresholds)
            sptr.detection_thresholds = spikesorter.detection_thresholds
    session.commit()

    # Re-open DB for reading.
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()
    rcg = RecordingChannelGroup.load(rcg_id)
    for unit in rcg.units:
        for sptr in unit.spiketrains:
            print('read threshold on spiketrain.id %s' % (sptr.id,))
            print(sptr.detection_thresholds)
noise_ratio=0.2, nb_segment=2, ) #~ for s, seg in enumerate(bl.segments): #~ for k, sptr in enumerate(seg.spiketrains): #~ print 's', s, 'k', k, sptr.size print bl.recordingchannelgroups oebl = OEBase.from_neo(bl, dbinfo.mapped_classes, cascade=True) print oebl.recordingchannelgroups oebl.save() else: print 'exist : transform to neo' bl = oebl.to_neo(cascade=True) rcg = bl.recordingchannelgroups[0] sps = spikesorter = SpikeSorter(rcg) if True: #~ spikesorter.ButterworthFilter( f_low = 200.) #~ spikesorter.MedianThresholdDetection(sign= '-', median_thresh = 6.,) #~ spikesorter.AlignWaveformOnPeak(left_sweep = 1*pq.ms , right_sweep = 2*pq.ms, sign = '-') #~ spikesorter.PcaFeature(n_components = 4) spikesorter.CombineFeature(use_peak=True, use_peak_to_valley=True, n_pca=3, n_ica=3, n_haar=3, sign='-') spikesorter.SklearnKMeans(n_cluster=5) spikesorter.check_display_attributes()