Ejemplo n.º 1
0
class BasicTest(unittest.TestCase):
    """Tests for SpikeSorter attribute access, step dispatch and chaining.

    Every test gets a fresh SpikeSorter built on the first
    RecordingChannelGroup of a block from ``generate_block_for_sorting``
    (3 units, 1 s duration, 2 segments), started in the
    ``'full_band_signal'`` state.
    """

    def setUp(self):
        bl = generate_block_for_sorting(
            nb_unit=3,
            duration=1. * pq.s,
            noise_ratio=0.2,
            nb_segment=2,
        )
        rcg = bl.recordingchannelgroups[0]
        self.sps = SpikeSorter(rcg, initial_state='full_band_signal')

    def test_getattr_aliases(self):
        """``segs`` aliases ``segments``; unknown names raise AttributeError."""
        self.assertIs(self.sps.segs, self.sps.segments)
        self.assertRaises(AttributeError, getattr, self.sps,
                          'i_love_my_mother')

    def test_getattr_runstep(self):
        """Calling a step by name runs it and records its instance in history."""
        self.sps.ButterworthFilter(f_low=200.)
        self.assertIsInstance(self.sps.history[-1]['methodInstance'],
                              ButterworthFilter)

    def test_one_standart_pipeline(self):
        """Run a full chain and check that each step fills its outputs.

        Chain: ButterworthFilter -> MedianThresholdDetection ->
        AlignWaveformOnDetection -> PcaFeature -> SklearnGaussianMixtureEm.
        """
        self.sps.ButterworthFilter(f_low=200.)
        self.assertIsNotNone(self.sps.filtered_sigs)

        self.sps.MedianThresholdDetection(
            sign='-',
            median_thresh=6,
        )
        self.assertIsNotNone(self.sps.spike_index_array)

        self.sps.AlignWaveformOnDetection(left_sweep=1 * pq.ms,
                                          right_sweep=2 * pq.ms)
        self.assertIsNotNone(self.sps.seg_spike_slices)
        self.assertIsNotNone(self.sps.spike_waveforms)
        self.assertIsNotNone(self.sps.left_sweep)
        self.assertIsNotNone(self.sps.right_sweep)

        self.sps.PcaFeature(n_components=3)
        self.assertIsNotNone(self.sps.waveform_features)

        self.sps.SklearnGaussianMixtureEm(n_cluster=12, n_iter=500)
        self.assertIsNotNone(self.sps.spike_clusters)
        self.assertIsNotNone(self.sps.cluster_names)

    def test_apply_history_to_other(self):
        """Replaying this sorter's history on another sorter reproduces it."""
        # Record at least one step first. setUp runs none, so without this
        # the history is empty and the replay below would check nothing.
        self.sps.ButterworthFilter(f_low=200.)

        sps2 = SpikeSorter(self.sps.rcg, initial_state='full_band_signal')
        self.sps.apply_history_to_other(sps2)

        # NOTE(review): assumes apply_history_to_other re-runs each recorded
        # step on the target sorter, so the filter output must now exist
        # there (test_one_standart_pipeline shows ButterworthFilter sets it).
        self.assertIsNotNone(sps2.filtered_sigs)
Ejemplo n.º 2
0
def test1():
    """Run a sorting chain and populate the neo RecordingChannelGroup.

    Builds a block with ``generate_block_for_sorting`` (6 units, 10 s,
    2 segments). On its first RecordingChannelGroup it runs:
    ButterworthFilter -> MedianThresholdDetection -> AlignWaveformOnPeak ->
    PcaFeature -> CombineFeature -> SklearnKMeans.

    The unit/segment spiketrain report is printed twice:
    - before ``populate_recordingchannelgroup()``, on the original group;
    - after it, on the group that call returns.
    Nothing is asserted; this is an inspection script.
    """
    def report(group):
        # One line per (unit u, segment s): "u <u> s <s> <same> <size>".
        # <same> tells whether the segment's spiketrain for unit u is the very
        # same object the unit holds for segment s.
        for u, unit in enumerate(group.units):
            for s, seg in enumerate(group.block.segments):
                sptr = seg.spiketrains[u]
                # Single-argument print(...) gives the same output under
                # Python 2 (print statement) and Python 3 (print function).
                print('u {} s {} {} {}'.format(
                    u, s, sptr is unit.spiketrains[s], sptr.size))

    # save to neo
    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.2,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]

    spikesorter = SpikeSorter(rcg)

    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.MedianThresholdDetection(
        sign='-',
        median_thresh=6.,
    )
    spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms,
                                    right_sweep=2 * pq.ms,
                                    sign='-')
    spikesorter.PcaFeature(n_components=4)
    spikesorter.CombineFeature(use_peak=True,
                               use_peak_to_valley=True,
                               n_pca=3,
                               n_ica=3,
                               n_haar=3,
                               sign='-')
    spikesorter.SklearnKMeans(n_cluster=3)

    report(rcg)

    rcg = spikesorter.populate_recordingchannelgroup()

    print('')  # blank separator line (same as a bare Python 2 `print`)

    report(rcg)
from OpenElectrophy.spikesorting import SpikeSorter, generate_block_for_sorting
import quantities as pq
# Build a test dataset and a SpikeSorter on its first RecordingChannelGroup.
bl = generate_block_for_sorting(
    nb_unit=6,
    duration=10. * pq.s,
    noise_ratio=0.2,
    nb_segment=2,
)
rcg = bl.recordingchannelgroups[0]
spikesorter = SpikeSorter(rcg, initial_state='full_band_signal')


# Run a processing chain: filter, threshold detection, alignment on peak.
spikesorter.ButterworthFilter(f_low=200.)
spikesorter.MedianThresholdDetection(sign='-', median_thresh=6.)
spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms,
                                right_sweep=2 * pq.ms,
                                sign='-')


# Display the widget interactively.
from OpenElectrophy.gui.spikesorting import AverageWaveforms
w1 = AverageWaveforms(spikesorter=spikesorter)
w1.show()


# Try another step method (alignment on detection), then refresh the widget.
spikesorter.AlignWaveformOnDetection(left_sweep=1 * pq.ms,
                                     right_sweep=2 * pq.ms)
w1.refresh()


Ejemplo n.º 4
0
def test2():
    """Run a sorting chain, store it in SQLite, reload it and print a report.

    Steps:
    - Open ``sqlite:///test_spikesorter.sqlite`` with a global session.
    - Save a block from ``generate_block_for_sorting`` (6 units, 10 s,
      2 segments) through ``OEBase.from_neo(...).save()``.
    - Run ButterworthFilter -> MedianThresholdDetection -> AlignWaveformOnPeak
      -> PcaFeature -> CombineFeature -> SklearnKMeans on its first
      RecordingChannelGroup.
    - Store the result with ``save_in_database``.
    - Reopen the database and reload the block by id.
    - Print per-segment and per-unit spiketrain counts, plus the unit/segment
      spiketrain report.

    Nothing is asserted. The SQLite file is never removed here.
    NOTE(review): it uses a relative path, so it presumably lands in the
    current working directory.
    """
    def report(group, size_of):
        # One line per (unit u, segment s): "u <u> s <s> <same> <size>".
        # <same> tells whether the segment's spiketrain for unit u is the very
        # same object the unit holds for segment s. <size> is size_of(sptr).
        for u, unit in enumerate(group.units):
            for s, seg in enumerate(group.block.segments):
                sptr = seg.spiketrains[u]
                # Single-argument print(...) gives the same output under
                # Python 2 (print statement) and Python 3 (print function).
                print('u {} s {} {} {}'.format(
                    u, s, sptr is unit.spiketrains[s], size_of(sptr)))

    # save to db
    url = 'sqlite:///test_spikesorter.sqlite'
    dbinfo = open_db(
        url=url,
        use_global_session=True,
        myglobals=globals(),
    )
    session = dbinfo.Session()

    bl = generate_block_for_sorting(
        nb_unit=6,
        duration=10. * pq.s,
        noise_ratio=0.2,
        nb_segment=2,
    )
    rcg = bl.recordingchannelgroups[0]
    oebl = OEBase.from_neo(bl, dbinfo.mapped_classes, cascade=True)
    oebl.save()
    id_bl = oebl.id

    report(rcg, lambda sptr: sptr.size)

    spikesorter = SpikeSorter(rcg)

    spikesorter.ButterworthFilter(f_low=200.)
    spikesorter.MedianThresholdDetection(
        sign='-',
        median_thresh=6.,
    )
    spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms,
                                    right_sweep=2 * pq.ms,
                                    sign='-')
    spikesorter.PcaFeature(n_components=4)
    spikesorter.CombineFeature(use_peak=True,
                               use_peak_to_valley=True,
                               n_pca=3,
                               n_ica=3,
                               n_haar=3,
                               sign='-')
    spikesorter.SklearnKMeans(n_cluster=3)

    spikesorter.save_in_database(session, dbinfo)

    # Reopen the database and reload the saved block by its id.
    dbinfo = open_db(
        url=url,
        use_global_session=True,
        myglobals=globals(),
    )
    # NOTE(review): this session is not used below. The call is kept in case
    # creating it matters for the global-session setup; confirm and drop it.
    session = dbinfo.Session()
    # `Block` is not imported in this function. Presumably open_db injects
    # the mapped classes into module globals via myglobals=globals().
    oebl = Block.load(id_bl)

    rcg = oebl.recordingchannelgroups[0]
    for s, seg in enumerate(rcg.block.segments):
        print('s {} {}'.format(s, len(seg.spiketrains)))

    for u, unit in enumerate(rcg.units):
        print('u {} {}'.format(u, len(unit.spiketrains)))

    # The reloaded spiketrains are sized through to_neo().
    # NOTE(review): presumably they are database objects, not neo SpikeTrains.
    report(rcg, lambda sptr: sptr.to_neo().size)