def test1():
    """Sort a synthetic recording and put the result back into the neo objects.

    A 6-unit, 2-segment block is generated and a complete chain (Butterworth
    filter, median-threshold detection, peak alignment, PCA, feature
    combination, k-means) is run on its first recording channel group.
    The unit/segment spike-train bookkeeping is printed once before and once
    after ``populate_recordingchannelgroup()``.
    """
    def show_spiketrains(group):
        # One line per (unit, segment): is the segment's u-th spike train the
        # same object as the unit's s-th one, and what is its size?
        for u, unit in enumerate(group.units):
            for s, seg in enumerate(group.block.segments):
                sptr = seg.spiketrains[u]
                # A single pre-formatted argument prints the same text with
                # the Python 2 print statement and the Python 3 function.
                print('u %s s %s %s %s'
                      % (u, s, sptr is unit.spiketrains[s], sptr.size))

    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.2,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]

    spikesorter = SpikeSorter(rcg)
    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.MedianThresholdDetection(sign='-', median_thresh=6.)
    spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms,
                                    right_sweep=2 * pq.ms,
                                    sign='-')
    spikesorter.PcaFeature(n_components=4)
    spikesorter.CombineFeature(use_peak=True, use_peak_to_valley=True,
                               n_pca=3, n_ica=3, n_haar=3, sign='-')
    spikesorter.SklearnKMeans(n_cluster=3)

    # State of the original group, before the sorter writes anything back.
    show_spiketrains(rcg)

    rcg = spikesorter.populate_recordingchannelgroup()

    print('')  # blank separator line

    # Same report on the group returned by the sorter.
    show_spiketrains(rcg)
Example #2
0
def test1():
    """Display the filtered signal after sliding-median filtering and peak detection.

    Builds a 6-unit, 2-segment synthetic block, applies ``SlidingMedianFilter``
    and ``RelativeThresholdDetection`` to its first recording channel group,
    prints the sorter, then shows a ``FilteredBandSignal`` widget and starts
    the Qt event loop with ``app.exec_()``.
    """
    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.2,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]

    spikesorter = SpikeSorter(rcg)

    spikesorter.SlidingMedianFilter(window_size=50. * pq.ms,
                                    sliding_step=25. * pq.ms,
                                    interpolation='spline')
    spikesorter.RelativeThresholdDetection(sign='-',
                                           relative_thresh=4.,
                                           noise_estimation='MAD',
                                           threshold_mode='peak',
                                           peak_span=0.5 * pq.ms)

    # Parenthesised single argument: valid and identical on Python 2 and 3.
    print(spikesorter)

    spikesorter.check_display_attributes()

    # GUI import is kept local so it only happens when the test is run.
    from OpenElectrophy.gui.spikesorting import FilteredBandSignal

    app = QApplication([])
    w1 = FilteredBandSignal(spikesorter=spikesorter)
    w1.refresh()
    w1.show()
    app.exec_()
Example #3
0
def test1():
    """Run peak-mode threshold detection on a noisy block and display the result.

    Uses a synthetic block with ``noise_ratio=0.7``. After Butterworth
    filtering and ``RelativeThresholdDetection`` in ``'peak'`` mode, the
    detection thresholds and the sorter are printed,
    ``populate_recordingchannelgroup(with_waveforms=False)`` is called, and a
    ``FilteredBandSignal`` widget is shown inside the Qt event loop.
    """
    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.7,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]

    spikesorter = SpikeSorter(rcg)

    spikesorter.ButterworthFilter(f_low=200.)
    # NOTE: threshold_mode='crossing' and noise_estimation='STD' were the
    # alternatives previously tried here.
    spikesorter.RelativeThresholdDetection(sign='-',
                                           relative_thresh=3.5,
                                           noise_estimation='MAD',
                                           threshold_mode='peak',
                                           peak_span=0.53 * pq.ms)

    # Parenthesised single argument: valid and identical on Python 2 and 3.
    print(spikesorter.detection_thresholds)
    print(spikesorter)

    spikesorter.populate_recordingchannelgroup(with_waveforms=False)

    spikesorter.check_display_attributes()
    from OpenElectrophy.gui.spikesorting import FilteredBandSignal
    app = QApplication([])
    w2 = FilteredBandSignal(spikesorter=spikesorter)
    w2.refresh()
    w2.show()
    app.exec_()
 def setUp(self):
     """Attach a new SpikeSorter, built on a small synthetic block, to the test."""
     # 3 units, 2 segments of 1 s each.
     synthetic_block = generate_block_for_sorting(nb_unit=3, duration=1. * pq.s,
                                                  noise_ratio=0.2, nb_segment=2)
     first_group = synthetic_block.recordingchannelgroups[0]
     self.sps = SpikeSorter(first_group, initial_state='full_band_signal')
def test1():
    """Cluster a noisy synthetic recording and show two display widgets.

    The chain is Butterworth filter, peak-mode relative threshold detection,
    alignment on the biggest-amplitude peak, PCA and Gaussian-mixture
    clustering. ``AverageWaveforms`` and ``FilteredBandSignal`` widgets are
    then shown and the Qt event loop is started.
    """
    synthetic_block = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                                 noise_ratio=0.7, nb_segment=2)
    sorter = SpikeSorter(synthetic_block.recordingchannelgroups[0])

    sorter.ButterworthFilter(f_low=200.)
    sorter.RelativeThresholdDetection(
        sign='-',
        relative_thresh=3.5,
        noise_estimation='MAD',
        threshold_mode='peak',
        peak_span=0.4 * pq.ms,
    )
    sorter.AlignWaveformOnPeak(
        left_sweep=1 * pq.ms,
        right_sweep=2 * pq.ms,
        sign='-',
        peak_method='biggest_amplitude',
    )
    sorter.PcaFeature(n_components=4)
    sorter.SklearnGaussianMixtureEm(n_cluster=5)

    sorter.check_display_attributes()
    from OpenElectrophy.gui.spikesorting import FilteredBandSignal, AverageWaveforms
    qt_app = QApplication([])

    # Keep a reference to every widget so none is discarded before the loop runs.
    widgets = []
    for widget_class in (AverageWaveforms, FilteredBandSignal):
        widget = widget_class(spikesorter=sorter)
        widget.refresh()
        widget.show()
        widgets.append(widget)

    qt_app.exec_()
def test2():
    """Sort a synthetic recording, save it to a SQLite database and reload it.

    The block is converted to OpenElectrophy mapped objects and saved, the
    sorting chain is run on the neo recording channel group, and the result is
    written with ``save_in_database``. The database is then reopened, the
    block reloaded by id, and the spike-train bookkeeping printed.
    """
    url = 'sqlite:///test_spikesorter.sqlite'
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()

    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.2,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]
    oebl = OEBase.from_neo(bl, dbinfo.mapped_classes, cascade=True)
    oebl.save()
    id_bl = oebl.id

    # Bookkeeping of the in-memory neo objects before sorting.
    # Each print gets one pre-formatted argument so the output is the same on
    # Python 2 and Python 3.
    for u, unit in enumerate(rcg.units):
        for s, seg in enumerate(rcg.block.segments):
            sptr = seg.spiketrains[u]
            print('u %s s %s %s %s'
                  % (u, s, sptr is unit.spiketrains[s], sptr.size))

    spikesorter = SpikeSorter(rcg)

    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.MedianThresholdDetection(sign='-', median_thresh=6.)
    spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms,
                                    right_sweep=2 * pq.ms,
                                    sign='-')
    spikesorter.PcaFeature(n_components=4)
    spikesorter.CombineFeature(use_peak=True, use_peak_to_valley=True,
                               n_pca=3, n_ica=3, n_haar=3, sign='-')
    spikesorter.SklearnKMeans(n_cluster=3)

    spikesorter.save_in_database(session, dbinfo)

    # Reopen the database and read everything back through the mapped classes.
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()
    oebl = Block.load(id_bl)

    rcg = oebl.recordingchannelgroups[0]
    for s, seg in enumerate(rcg.block.segments):
        print('s %s %s' % (s, len(seg.spiketrains)))

    for u, unit in enumerate(rcg.units):
        print('u %s %s' % (u, len(unit.spiketrains)))

    for u, unit in enumerate(rcg.units):
        for s, seg in enumerate(rcg.block.segments):
            sptr = seg.spiketrains[u]
            # Database objects are converted with to_neo() to read the size.
            print('u %s s %s %s %s'
                  % (u, s, sptr is unit.spiketrains[s], sptr.to_neo().size))
Example #7
0
def test4():
    """Delete cluster 0 from a new sorter and print its state before and after.

    The sorter is built directly on the synthetic recording channel group; no
    filtering, detection or clustering step is run by this function.
    """
    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=1. * pq.s,
        noise_ratio=0.2,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]

    sps = SpikeSorter(rcg)

    def show_state():
        # Parenthesised single-argument prints give the same output on
        # Python 2 and Python 3.
        print(sps.spike_index_array[0].shape)
        print(sps.spike_waveforms.shape)
        print(sps.seg_spike_slices)
        print(sps.waveform_features)
        print(sps.spike_clusters)

    show_state()

    sps.delete_one_cluster(0)
    print('')  # blank separator line

    show_state()
class BasicTest(unittest.TestCase):
    """Basic checks of SpikeSorter on a small synthetic recording."""

    def setUp(self):
        # Every test starts from a new 3-unit, 2-segment, 1 s block.
        synthetic_block = generate_block_for_sorting(
            nb_unit=3, duration=1. * pq.s, noise_ratio=0.2, nb_segment=2)
        first_group = synthetic_block.recordingchannelgroups[0]
        self.sps = SpikeSorter(first_group, initial_state='full_band_signal')

    def tearDown(self):
        """Nothing to release."""

    def test_getattr_aliases(self):
        """``segs`` is the same object as ``segments``; unknown names raise."""
        sorter = self.sps
        self.assertIs(sorter.segs, sorter.segments)
        with self.assertRaises(AttributeError):
            getattr(sorter, 'i_love_my_mother')

    def test_getattr_runstep(self):
        """Calling ``ButterworthFilter`` on the sorter records it in ``history``."""
        self.sps.ButterworthFilter(f_low=200.)
        last_step = self.sps.history[-1]
        self.assertIsInstance(last_step['methodInstance'], ButterworthFilter)

    def test_one_standart_pipeline(self):
        """Each step of a standard chain leaves its result attributes set."""
        sorter = self.sps

        def check_set(*attribute_names):
            # Every listed attribute must be non-None after the step.
            for attribute_name in attribute_names:
                self.assertIsNotNone(getattr(sorter, attribute_name))

        sorter.ButterworthFilter(f_low=200.)
        check_set('filtered_sigs')

        sorter.MedianThresholdDetection(sign='-', median_thresh=6)
        check_set('spike_index_array')

        sorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms,
                                        right_sweep=2 * pq.ms)
        check_set('seg_spike_slices', 'spike_waveforms',
                  'left_sweep', 'right_sweep')

        sorter.PcaFeature(n_components=3)
        check_set('waveform_features')

        sorter.SklearnGaussianMixtureEm(n_cluster=12, n_iter=500)
        check_set('spike_clusters', 'cluster_names')

    def test_apply_history_to_other(self):
        """``apply_history_to_other`` accepts a second sorter on the same group."""
        other_sorter = SpikeSorter(self.sps.rcg,
                                   initial_state='full_band_signal')
        self.sps.apply_history_to_other(other_sorter)
Example #9
0
def test1():
    """Extract waveforms aligned on detection and show them in two widgets.

    After Butterworth filtering, crossing-mode relative threshold detection
    and ``AlignWaveformOnDetection``, the sorter is printed and the
    ``AverageWaveforms`` and ``AllWaveforms`` widgets are shown inside the Qt
    event loop.
    """
    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.2,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]

    spikesorter = SpikeSorter(rcg)

    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.RelativeThresholdDetection(sign='-',
                                           relative_thresh=4.,
                                           noise_estimation='MAD',
                                           threshold_mode='crossing')

    # NOTE: AlignWaveformOnPeak and AlignWaveformOnCentralWaveform were the
    # alternatives previously tried here.
    spikesorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms,
                                         right_sweep=2 * pq.ms,
                                         sign='-')

    # Parenthesised single argument: valid and identical on Python 2 and 3.
    print(spikesorter)

    spikesorter.check_display_attributes()
    from OpenElectrophy.gui.spikesorting import AverageWaveforms, AllWaveforms

    app = QApplication([])
    w1 = AverageWaveforms(spikesorter=spikesorter)
    w1.refresh()
    w1.show()
    w2 = AllWaveforms(spikesorter=spikesorter)
    w2.refresh()
    w2.show()
    app.exec_()
class BasicTest(unittest.TestCase):
    """Unit tests exercising SpikeSorter on generated data."""

    def setUp(self):
        """Create ``self.sps`` on the first channel group of a synthetic block."""
        generated = generate_block_for_sorting(nb_unit=3,
                                               duration=1. * pq.s,
                                               noise_ratio=0.2,
                                               nb_segment=2)
        self.sps = SpikeSorter(generated.recordingchannelgroups[0],
                               initial_state='full_band_signal')

    def tearDown(self):
        """No cleanup is needed."""

    def test_getattr_aliases(self):
        """The ``segs`` alias returns ``segments``; a bogus name raises."""
        sps = self.sps
        self.assertIs(sps.segs, sps.segments)
        with self.assertRaises(AttributeError):
            getattr(sps, 'i_love_my_mother')

    def test_getattr_runstep(self):
        """The last history entry holds a ``ButterworthFilter`` instance."""
        self.sps.ButterworthFilter(f_low=200.)
        method_instance = self.sps.history[-1]['methodInstance']
        self.assertIsInstance(method_instance, ButterworthFilter)

    def test_one_standart_pipeline(self):
        """Run filter, detection, alignment, PCA and clustering in sequence."""
        sps = self.sps

        sps.ButterworthFilter(f_low=200.)
        self.assertIsNotNone(sps.filtered_sigs)

        sps.MedianThresholdDetection(sign='-', median_thresh=6)
        self.assertIsNotNone(sps.spike_index_array)

        sweeps = dict(left_sweep=1 * pq.ms, right_sweep=2 * pq.ms)
        sps.AlignWaveformOnDetection(**sweeps)
        self.assertIsNotNone(sps.seg_spike_slices)
        self.assertIsNotNone(sps.spike_waveforms)
        self.assertIsNotNone(sps.left_sweep)
        self.assertIsNotNone(sps.right_sweep)

        sps.PcaFeature(n_components=3)
        self.assertIsNotNone(sps.waveform_features)

        sps.SklearnGaussianMixtureEm(n_cluster=12, n_iter=500)
        self.assertIsNotNone(sps.spike_clusters)
        self.assertIsNotNone(sps.cluster_names)

    def test_apply_history_to_other(self):
        """Pass a second sorter of the same group to ``apply_history_to_other``."""
        second = SpikeSorter(self.sps.rcg, initial_state='full_band_signal')
        self.sps.apply_history_to_other(second)
Example #11
0
def test1():
    """Apply a sliding median filter and peak detection, then show the signal.

    The sorter is printed after the two steps, and a ``FilteredBandSignal``
    widget is displayed inside the Qt event loop started by ``app.exec_()``.
    """
    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.2,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]

    spikesorter = SpikeSorter(rcg)

    # NOTE: ButterworthFilter and DerivativeFilter were the alternatives
    # previously tried in place of the sliding median.
    spikesorter.SlidingMedianFilter(window_size=50. * pq.ms,
                                    sliding_step=25. * pq.ms,
                                    interpolation='spline')

    spikesorter.RelativeThresholdDetection(sign='-',
                                           relative_thresh=4.,
                                           noise_estimation='MAD',
                                           threshold_mode='peak',
                                           peak_span=0.5 * pq.ms)

    # Parenthesised single argument: valid and identical on Python 2 and 3.
    print(spikesorter)

    spikesorter.check_display_attributes()

    from OpenElectrophy.gui.spikesorting import FilteredBandSignal

    app = QApplication([])
    w1 = FilteredBandSignal(spikesorter=spikesorter)
    w1.refresh()
    w1.show()
    app.exec_()
Example #12
0
def test1():
    """Run a full chain and compare spike-train bookkeeping before and after.

    The report printed for every (unit, segment) pair shows whether the
    segment's spike train is the same object as the unit's, and its size.
    It is printed for the original group and again for the group returned by
    ``populate_recordingchannelgroup()``.
    """
    def report(group):
        for u, unit in enumerate(group.units):
            for s, seg in enumerate(group.block.segments):
                sptr = seg.spiketrains[u]
                # One pre-formatted argument keeps the output identical on
                # Python 2 and Python 3.
                print('u %s s %s %s %s'
                      % (u, s, sptr is unit.spiketrains[s], sptr.size))

    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.2,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]

    spikesorter = SpikeSorter(rcg)

    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.MedianThresholdDetection(
        sign='-',
        median_thresh=6.,
    )
    spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms,
                                    right_sweep=2 * pq.ms,
                                    sign='-')
    spikesorter.PcaFeature(n_components=4)
    spikesorter.CombineFeature(use_peak=True,
                               use_peak_to_valley=True,
                               n_pca=3,
                               n_ica=3,
                               n_haar=3,
                               sign='-')
    spikesorter.SklearnKMeans(n_cluster=3)

    report(rcg)

    rcg = spikesorter.populate_recordingchannelgroup()

    print('')  # blank separator line

    report(rcg)
def test1():
    """Re-align waveforms on the central waveform and plot the iterations.

    A first pass (filter, peak detection, alignment on detection, PCA,
    k-means) is followed by ``AlignWaveformOnCentralWaveform``. The method
    instance stored in the last history entry then draws its iterative
    centers on a matplotlib figure.
    """
    synthetic_block = generate_block_for_sorting(nb_unit=6, duration=10. * pq.s,
                                                 noise_ratio=0.2, nb_segment=2)
    sorter = SpikeSorter(synthetic_block.recordingchannelgroups[0])

    sorter.ButterworthFilter(f_low=200.)
    sorter.RelativeThresholdDetection(sign='-', relative_thresh=4.,
                                      noise_estimation='MAD',
                                      threshold_mode='peak')

    sorter.AlignWaveformOnDetection(left_sweep=1.5 * pq.ms,
                                    right_sweep=2.5 * pq.ms, sign='-')
    sorter.PcaFeature(n_components=4)
    sorter.SklearnKMeans(n_cluster=3)

    # Other values seen for these options in earlier experiments:
    # shift_estimation_method 'taylor order2' / 'optimize', shift_method 'spline'.
    sorter.AlignWaveformOnCentralWaveform(left_sweep=1 * pq.ms,
                                          right_sweep=2 * pq.ms,
                                          shift_estimation_method='taylor order1',
                                          shift_method='sinc',
                                          max_iter=3)

    # The alignment step just run is the last entry of the history.
    alignment = sorter.history[-1]['methodInstance']

    figure = pyplot.figure()
    alignment.plot_iterative_centers(figure, sorter)

    pyplot.show()
Example #14
0
def test1():
    """Show average and individual waveforms extracted around detections.

    Chain: Butterworth filter, crossing-mode relative threshold detection,
    ``AlignWaveformOnDetection``. The sorter is printed, then the
    ``AverageWaveforms`` and ``AllWaveforms`` widgets are shown and the Qt
    event loop is started.
    """
    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.2,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]

    spikesorter = SpikeSorter(rcg)

    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.RelativeThresholdDetection(sign='-',
                                           relative_thresh=4.,
                                           noise_estimation='MAD',
                                           threshold_mode='crossing')

    # NOTE: AlignWaveformOnPeak ('biggest_amplitude' / 'closer') and
    # AlignWaveformOnCentralWaveform were the alternatives previously tried.
    spikesorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms,
                                         right_sweep=2 * pq.ms,
                                         sign='-')

    # Parenthesised single argument: valid and identical on Python 2 and 3.
    print(spikesorter)

    spikesorter.check_display_attributes()
    from OpenElectrophy.gui.spikesorting import AverageWaveforms, AllWaveforms

    app = QApplication([])
    w1 = AverageWaveforms(spikesorter=spikesorter)
    w1.refresh()
    w1.show()
    w2 = AllWaveforms(spikesorter=spikesorter)
    w2.refresh()
    w2.show()
    app.exec_()
Example #15
0
def test1():
    """Detect peaks on a noisy synthetic block and display the filtered signal.

    The block is generated with ``noise_ratio=0.7``. The detection thresholds
    and the sorter are printed after ``RelativeThresholdDetection``;
    ``populate_recordingchannelgroup(with_waveforms=False)`` is then called
    and a ``FilteredBandSignal`` widget is shown inside the Qt event loop.
    """
    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.7,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]

    spikesorter = SpikeSorter(rcg)

    spikesorter.ButterworthFilter(f_low=200.)
    # NOTE: 'crossing' mode and 'STD' noise estimation were the alternatives
    # previously tried for this step.
    spikesorter.RelativeThresholdDetection(sign='-',
                                           relative_thresh=3.5,
                                           noise_estimation='MAD',
                                           threshold_mode='peak',
                                           peak_span=0.53 * pq.ms)

    # Parenthesised single argument: valid and identical on Python 2 and 3.
    print(spikesorter.detection_thresholds)
    print(spikesorter)

    spikesorter.populate_recordingchannelgroup(with_waveforms=False)

    spikesorter.check_display_attributes()
    from OpenElectrophy.gui.spikesorting import FilteredBandSignal
    app = QApplication([])
    w2 = FilteredBandSignal(spikesorter=spikesorter)
    w2.refresh()
    w2.show()
    app.exec_()
Example #16
0
 def setUp(self):
     """Store on ``self.sps`` a SpikeSorter built from a generated block."""
     generated = generate_block_for_sorting(
         nb_unit=3,
         duration=1. * pq.s,
         noise_ratio=0.2,
         nb_segment=2,
     )
     # Only the first recording channel group is used.
     self.sps = SpikeSorter(generated.recordingchannelgroups[0],
                            initial_state='full_band_signal')
Example #17
0
import sys
sys.path.append('..')

if __name__ == '__main__':
    # Example script: list the spike trains of a synthetic block, then apply
    # filtering, detection, waveform extraction and PCA to it.

    import quantities as pq
    from OpenElectrophy.spikesorting import (generate_block_for_sorting, SpikeSorter)

    # read or create datasets
    bl = generate_block_for_sorting(nb_unit=6,
                                    duration=5. * pq.s,
                                    noise_ratio=0.2,
                                    )
    rcg = bl.recordingchannelgroups[0]
    spikesorter = SpikeSorter(rcg)

    # display unit before sorting
    # Each print receives one pre-formatted string, so the output is the same
    # under the Python 2 print statement and the Python 3 print function.
    for u, unit in enumerate(rcg.units):
        print('%s unit name %s' % (u, unit.name))
        for s, seg in enumerate(rcg.block.segments):
            sptr = seg.spiketrains[u]
            # Two spaces after "with" reproduce the original output exactly.
            print(' in Segment %s has SpikeTrain with  %s' % (s, sptr.size))

    # Apply a chain
    spikesorter.ButterworthFilter(f_low=200.)
    # equivalent to
    # spikesorter.run_step(ButterworthFilter, f_low = 200.)
    spikesorter.MedianThresholdDetection(sign='-', median_thresh=6)
    spikesorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms, right_sweep=2 * pq.ms)
    spikesorter.PcaFeature(n_components=6)
Example #18
0
if __name__ == '__main__':
    # Example script: sort a synthetic block and prepare it for the GUI widgets.

    import quantities as pq
    from OpenElectrophy.spikesorting import generate_block_for_sorting, SpikeSorter

    from OpenElectrophy.gui.spikesorting import (AverageWaveforms,
                                                 FeaturesNDViewer,
                                                 FilteredBandSignal)

    # Generate the dataset and wrap its first channel group in a sorter.
    bl = generate_block_for_sorting(nb_unit=6, duration=5. * pq.s, noise_ratio=0.2)
    rcg = bl.recordingchannelgroups[0]
    spikesorter = SpikeSorter(rcg)

    # Sorting chain. Each call is shorthand for run_step, e.g.
    # spikesorter.run_step(ButterworthFilter, f_low = 200.)
    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.MedianThresholdDetection(sign='-', median_thresh=6)
    spikesorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms, right_sweep=2 * pq.ms)
    spikesorter.PcaFeature(n_components=6)
    spikesorter.SklearnGaussianMixtureEm(n_cluster=6, n_iter=200)

    # The Qt application is created before the display attributes are checked.
    from PyQt4.QtGui import QApplication
    app = QApplication([])

    spikesorter.check_display_attributes()
 def test_apply_history_to_other(self):
     """Call ``apply_history_to_other`` with a second sorter on the same group."""
     other_sorter = SpikeSorter(self.sps.rcg, initial_state='full_band_signal')
     self.sps.apply_history_to_other(other_sorter)
Example #20
0
def test1():
    """Store detection thresholds in a new ``SpikeTrain`` column and read them back.

    Steps:
      1. open the SQLite database and add a ``detection_thresholds`` column
         to the ``SpikeTrain`` table if it does not exist;
      2. reopen the database, save a synthetic block, run the sorting chain
         and save the result;
      3. assign the sorter's thresholds to every spike train of the mapped
         recording channel group and commit;
      4. reopen the database once more and print the stored thresholds.
    """
    def show_spiketrain_attributes(info):
        # List the usable attributes of the mapped SpikeTrain class.
        spiketrain_class = info.classes_by_name['SpikeTrain']
        for attrname, attrtype in spiketrain_class.usable_attributes.items():
            # One pre-formatted argument keeps the output identical on
            # Python 2 and Python 3.
            print('%s %s' % (attrname, attrtype))

    # open DB and create new column
    url = 'sqlite:///test_spikesorter.sqlite'
    dbinfo = open_db(
        url=url,
        use_global_session=True,
        myglobals=globals(),
    )
    session = dbinfo.Session()

    show_spiketrain_attributes(dbinfo)

    # Create a column on table SpikeTrain (maybe this could be stored at
    # Unit level).
    create_column_if_not_exists(dbinfo.metadata.tables['SpikeTrain'],
                                'detection_thresholds', np.ndarray)

    # Check whether the creation worked.
    show_spiketrain_attributes(dbinfo)

    # re open DB for sorting
    dbinfo = open_db(
        url=url,
        use_global_session=True,
        myglobals=globals(),
    )
    session = dbinfo.Session()
    print(dbinfo.classes_by_name['SpikeTrain'])
    show_spiketrain_attributes(dbinfo)

    neobl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.2,
        nb_segment=2,
    )
    neorcg = neobl.recordingchannelgroups[0]
    bl = OEBase.from_neo(neobl, dbinfo.mapped_classes, cascade=True)
    bl.save()
    rcg = OEBase.from_neo(neorcg, dbinfo.mapped_classes, cascade=True)
    rcg_id = rcg.id

    # spike sorting chain
    spikesorter = SpikeSorter(neorcg)
    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.RelativeThresholdDetection(
        sign='-',
        relative_thresh=6.,
    )
    spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms,
                                    right_sweep=2 * pq.ms,
                                    sign='-')
    spikesorter.PcaFeature(n_components=4)
    spikesorter.CombineFeature(use_peak=True,
                               use_peak_to_valley=True,
                               n_pca=3,
                               n_ica=3,
                               n_haar=3,
                               sign='-')
    spikesorter.SklearnKMeans(n_cluster=3)
    spikesorter.save_in_database(session, dbinfo)

    # Note that threshold is per channel/segment
    for unit in rcg.units:
        for sptr in unit.spiketrains:
            print('writte threshold on spiketrain.id %s' % (sptr.id,))
            print(spikesorter.detection_thresholds)
            sptr.detection_thresholds = spikesorter.detection_thresholds
    session.commit()

    # re open DB for reading
    dbinfo = open_db(
        url=url,
        use_global_session=True,
        myglobals=globals(),
    )
    session = dbinfo.Session()
    rcg = RecordingChannelGroup.load(rcg_id)
    for unit in rcg.units:
        for sptr in unit.spiketrains:
            print('read threshold on spiketrain.id %s' % (sptr.id,))
            print(sptr.detection_thresholds)
import sys

sys.path.append("..")

if __name__ == "__main__":
    # Example script: run a sorting chain on generated data, then create the
    # Qt application and check the sorter's display attributes.

    import quantities as pq
    from OpenElectrophy.spikesorting import (
        generate_block_for_sorting,
        SpikeSorter,
    )

    from OpenElectrophy.gui.spikesorting import (
        AverageWaveforms,
        FeaturesNDViewer,
        FilteredBandSignal,
    )

    # Dataset: 6 units, 5 s, first recording channel group only.
    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=5.0 * pq.s,
        noise_ratio=0.2,
    )
    rcg = bl.recordingchannelgroups[0]
    spikesorter = SpikeSorter(rcg)

    # Chain of steps; the first one could also be written
    # spikesorter.run_step(ButterworthFilter, f_low = 200.)
    spikesorter.ButterworthFilter(f_low=200.0)
    spikesorter.MedianThresholdDetection(sign="-", median_thresh=6)
    spikesorter.AlignWaveformOnDetection(
        left_sweep=1 * pq.ms,
        right_sweep=2 * pq.ms,
    )
    spikesorter.PcaFeature(n_components=6)
    spikesorter.SklearnGaussianMixtureEm(n_cluster=6, n_iter=200)

    from PyQt4.QtGui import QApplication

    app = QApplication([])

    spikesorter.check_display_attributes()
Example #22
0
def test2():
    """Save a sorted synthetic recording to SQLite and inspect it after reload.

    The neo block is mapped to OpenElectrophy objects and saved, the sorting
    chain is run on the neo recording channel group, and
    ``save_in_database`` writes the result. After reopening the database the
    block is loaded by id and the spike-train counts and identities printed.
    """
    url = 'sqlite:///test_spikesorter.sqlite'
    dbinfo = open_db(
        url=url,
        use_global_session=True,
        myglobals=globals(),
    )
    session = dbinfo.Session()

    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.2,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]
    oebl = OEBase.from_neo(bl, dbinfo.mapped_classes, cascade=True)
    oebl.save()
    id_bl = oebl.id

    # Each print gets one pre-formatted argument so the output is the same on
    # Python 2 and Python 3.
    for u, unit in enumerate(rcg.units):
        for s, seg in enumerate(rcg.block.segments):
            sptr = seg.spiketrains[u]
            print('u %s s %s %s %s'
                  % (u, s, sptr is unit.spiketrains[s], sptr.size))

    spikesorter = SpikeSorter(rcg)

    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.MedianThresholdDetection(
        sign='-',
        median_thresh=6.,
    )
    spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms,
                                    right_sweep=2 * pq.ms,
                                    sign='-')
    spikesorter.PcaFeature(n_components=4)
    spikesorter.CombineFeature(use_peak=True,
                               use_peak_to_valley=True,
                               n_pca=3,
                               n_ica=3,
                               n_haar=3,
                               sign='-')
    spikesorter.SklearnKMeans(n_cluster=3)

    spikesorter.save_in_database(session, dbinfo)

    # Reopen the database before reading the block back.
    dbinfo = open_db(
        url=url,
        use_global_session=True,
        myglobals=globals(),
    )
    session = dbinfo.Session()
    oebl = Block.load(id_bl)

    rcg = oebl.recordingchannelgroups[0]
    for s, seg in enumerate(rcg.block.segments):
        print('s %s %s' % (s, len(seg.spiketrains)))

    for u, unit in enumerate(rcg.units):
        print('u %s %s' % (u, len(unit.spiketrains)))

    for u, unit in enumerate(rcg.units):
        for s, seg in enumerate(rcg.block.segments):
            sptr = seg.spiketrains[u]
            # Database objects are converted with to_neo() to read the size.
            print('u %s s %s %s %s'
                  % (u, s, sptr is unit.spiketrains[s], sptr.to_neo().size))
"""


# DO NOT FORGET this IPython command
%gui qt

from OpenElectrophy.spikesorting import SpikeSorter, generate_block_for_sorting
import quantities as pq
# generate dataset
bl = generate_block_for_sorting(nb_unit = 6,
                                                    duration = 10.*pq.s,
                                                    noise_ratio = 0.2,
                                                    nb_segment = 2,
                                                    )
rcg = bl.recordingchannelgroups[0]
spikesorter = SpikeSorter(rcg, initial_state='full_band_signal')


# Apply a chain
spikesorter.ButterworthFilter( f_low = 200.)
spikesorter.MedianThresholdDetection(sign= '-', median_thresh = 6.,)
spikesorter.AlignWaveformOnPeak(left_sweep = 1*pq.ms , right_sweep = 2*pq.ms, sign = '-')


# display widget interactively
from OpenElectrophy.gui.spikesorting import AverageWaveforms
w1 = AverageWaveforms(spikesorter = spikesorter)
w1.show()


# test another step methods
def test1():
    """Add a ``detection_thresholds`` column to ``SpikeTrain`` and round-trip it.

    The database is opened three times: first to create the column if it is
    missing, then to save a synthetic block, sort it and write the sorter's
    thresholds on every spike train of the mapped recording channel group,
    and finally to reload the group by id and print what was stored.
    """
    def list_spiketrain_attributes(info):
        # Print name and type of every usable attribute of mapped SpikeTrain.
        usable = info.classes_by_name['SpikeTrain'].usable_attributes
        for attrname, attrtype in usable.items():
            # One pre-formatted argument keeps the output identical on
            # Python 2 and Python 3.
            print('%s %s' % (attrname, attrtype))

    # open DB and create new column
    url = 'sqlite:///test_spikesorter.sqlite'
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()

    list_spiketrain_attributes(dbinfo)

    # Create a column on table SpikeTrain (maybe this could be stored at
    # Unit level).
    create_column_if_not_exists(dbinfo.metadata.tables['SpikeTrain'],
                                'detection_thresholds', np.ndarray)

    # Check whether the creation worked.
    list_spiketrain_attributes(dbinfo)

    # re open DB for sorting
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()
    print(dbinfo.classes_by_name['SpikeTrain'])
    list_spiketrain_attributes(dbinfo)

    neobl = generate_block_for_sorting(nb_unit=6,
                                       duration=10. * pq.s,
                                       noise_ratio=0.2,
                                       nb_segment=2,
                                       )
    neorcg = neobl.recordingchannelgroups[0]
    bl = OEBase.from_neo(neobl, dbinfo.mapped_classes, cascade=True)
    bl.save()
    rcg = OEBase.from_neo(neorcg, dbinfo.mapped_classes, cascade=True)
    rcg_id = rcg.id

    # spike sorting chain
    spikesorter = SpikeSorter(neorcg)
    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.RelativeThresholdDetection(sign='-', relative_thresh=6.)
    spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms,
                                    right_sweep=2 * pq.ms, sign='-')
    spikesorter.PcaFeature(n_components=4)
    spikesorter.CombineFeature(use_peak=True, use_peak_to_valley=True,
                               n_pca=3, n_ica=3, n_haar=3, sign='-')
    spikesorter.SklearnKMeans(n_cluster=3)
    spikesorter.save_in_database(session, dbinfo)

    # Note that threshold is per channel/segment
    for unit in rcg.units:
        for sptr in unit.spiketrains:
            print('writte threshold on spiketrain.id %s' % (sptr.id,))
            print(spikesorter.detection_thresholds)
            sptr.detection_thresholds = spikesorter.detection_thresholds
    session.commit()

    # re open DB for reading
    dbinfo = open_db(url=url, use_global_session=True, myglobals=globals())
    session = dbinfo.Session()
    rcg = RecordingChannelGroup.load(rcg_id)
    for unit in rcg.units:
        for sptr in unit.spiketrains:
            print('read threshold on spiketrain.id %s' % (sptr.id,))
            print(sptr.detection_thresholds)
        noise_ratio=0.2,
        nb_segment=2,
    )
    #~ for s, seg in enumerate(bl.segments):
    #~ for k, sptr in enumerate(seg.spiketrains):
    #~ print 's', s, 'k', k, sptr.size
    print bl.recordingchannelgroups
    oebl = OEBase.from_neo(bl, dbinfo.mapped_classes, cascade=True)
    print oebl.recordingchannelgroups
    oebl.save()
else:
    print 'exist : transform to neo'
    bl = oebl.to_neo(cascade=True)
# Work on the first recording channel group of the block obtained above.
rcg = bl.recordingchannelgroups[0]

# `sps` and `spikesorter` are two names for the same SpikeSorter instance.
sps = spikesorter = SpikeSorter(rcg)

# Always-true toggle around the sorting steps; the filter, detection,
# alignment and PCA steps are commented out, so only CombineFeature and
# SklearnKMeans are run here.
if True:
    #~ spikesorter.ButterworthFilter( f_low = 200.)
    #~ spikesorter.MedianThresholdDetection(sign= '-', median_thresh = 6.,)
    #~ spikesorter.AlignWaveformOnPeak(left_sweep = 1*pq.ms , right_sweep = 2*pq.ms, sign = '-')
    #~ spikesorter.PcaFeature(n_components = 4)
    spikesorter.CombineFeature(use_peak=True,
                               use_peak_to_valley=True,
                               n_pca=3,
                               n_ica=3,
                               n_haar=3,
                               sign='-')
    spikesorter.SklearnKMeans(n_cluster=5)

# Called unconditionally, whether or not the steps above ran.
spikesorter.check_display_attributes()