def test1(): #save to neo bl = generate_block_for_sorting( nb_unit=6, duration=10. * pq.s, noise_ratio=0.2, nb_segment=2, ) rcg = bl.recordingchannelgroups[0] spikesorter = SpikeSorter(rcg) spikesorter.ButterworthFilter(f_low=200.) spikesorter.MedianThresholdDetection( sign='-', median_thresh=6., ) spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms, right_sweep=2 * pq.ms, sign='-') spikesorter.PcaFeature(n_components=4) spikesorter.CombineFeature(use_peak=True, use_peak_to_valley=True, n_pca=3, n_ica=3, n_haar=3, sign='-') spikesorter.SklearnKMeans(n_cluster=3) for u, unit in enumerate(rcg.units): for s, seg in enumerate(rcg.block.segments): sptr = seg.spiketrains[u] print 'u', u, 's', s, seg.spiketrains[u] is unit.spiketrains[ s], sptr.size rcg = spikesorter.populate_recordingchannelgroup() print for u, unit in enumerate(rcg.units): for s, seg in enumerate(rcg.block.segments): sptr = seg.spiketrains[u] print 'u', u, 's', s, seg.spiketrains[u] is unit.spiketrains[ s], sptr.size
def test1(): # open DB and create new column url = 'sqlite:///test_spikesorter.sqlite' dbinfo = open_db( url=url, use_global_session=True, myglobals=globals(), ) session = dbinfo.Session() for attrname, attrtype in dbinfo.classes_by_name[ 'SpikeTrain'].usable_attributes.items(): print attrname, attrtype #Create acolumn on table SpikeTrain (maybe this could be store at Unit level) create_column_if_not_exists(dbinfo.metadata.tables['SpikeTrain'], 'detection_thresholds', np.ndarray) #test is creation is OK for attrname, attrtype in dbinfo.classes_by_name[ 'SpikeTrain'].usable_attributes.items(): print attrname, attrtype # re open DB for sorting url = 'sqlite:///test_spikesorter.sqlite' dbinfo = open_db( url=url, use_global_session=True, myglobals=globals(), ) session = dbinfo.Session() print dbinfo.classes_by_name['SpikeTrain'] for attrname, attrtype in dbinfo.classes_by_name[ 'SpikeTrain'].usable_attributes.items(): print attrname, attrtype neobl = generate_block_for_sorting( nb_unit=6, duration=10. * pq.s, noise_ratio=0.2, nb_segment=2, ) neorcg = neobl.recordingchannelgroups[0] bl = OEBase.from_neo(neobl, dbinfo.mapped_classes, cascade=True) bl.save() rcg = OEBase.from_neo(neorcg, dbinfo.mapped_classes, cascade=True) rcg_id = rcg.id # spike sorting chain spikesorter = SpikeSorter(neorcg) spikesorter.ButterworthFilter(f_low=200.) spikesorter.RelativeThresholdDetection( sign='-', relative_thresh=6., ) spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms, right_sweep=2 * pq.ms, sign='-') spikesorter.PcaFeature(n_components=4) spikesorter.CombineFeature(use_peak=True, use_peak_to_valley=True, n_pca=3, n_ica=3, n_haar=3, sign='-') spikesorter.SklearnKMeans(n_cluster=3) spikesorter.save_in_database(session, dbinfo) # Note that threshold is per channel/segment for unit in rcg.units: for sptr in unit.spiketrains: print 'writte threshold on spiketrain.id', sptr.id print spikesorter.detection_thresholds sptr.detection_thresholds = spikesorter.detection_thresholds session.commit() # re open DB for sorting for reading dbinfo = open_db( url=url, use_global_session=True, myglobals=globals(), ) session = dbinfo.Session() rcg = RecordingChannelGroup.load(rcg_id) for unit in rcg.units: for sptr in unit.spiketrains: print 'read threshold on spiketrain.id', sptr.id print sptr.detection_thresholds
def test2(): #save to db url = 'sqlite:///test_spikesorter.sqlite' dbinfo = open_db( url=url, use_global_session=True, myglobals=globals(), ) session = dbinfo.Session() bl = generate_block_for_sorting( nb_unit=6, duration=10. * pq.s, noise_ratio=0.2, nb_segment=2, ) rcg = bl.recordingchannelgroups[0] oebl = OEBase.from_neo(bl, dbinfo.mapped_classes, cascade=True) #~ print oebl is bl.OEinstance oebl.save() id_bl = oebl.id for u, unit in enumerate(rcg.units): for s, seg in enumerate(rcg.block.segments): sptr = seg.spiketrains[u] print 'u', u, 's', s, seg.spiketrains[u] is unit.spiketrains[ s], sptr.size spikesorter = SpikeSorter(rcg) spikesorter.ButterworthFilter(f_low=200.) spikesorter.MedianThresholdDetection( sign='-', median_thresh=6., ) spikesorter.AlignWaveformOnPeak(left_sweep=1 * pq.ms, right_sweep=2 * pq.ms, sign='-') spikesorter.PcaFeature(n_components=4) spikesorter.CombineFeature(use_peak=True, use_peak_to_valley=True, n_pca=3, n_ica=3, n_haar=3, sign='-') spikesorter.SklearnKMeans(n_cluster=3) spikesorter.save_in_database(session, dbinfo) dbinfo = open_db( url=url, use_global_session=True, myglobals=globals(), ) session = dbinfo.Session() oebl = Block.load(id_bl) rcg = oebl.recordingchannelgroups[0] for s, seg in enumerate(rcg.block.segments): print 's', s, len(seg.spiketrains) for u, unit in enumerate(rcg.units): print 'u', u, len(unit.spiketrains) for u, unit in enumerate(rcg.units): for s, seg in enumerate(rcg.block.segments): sptr = seg.spiketrains[u] print 'u', u, 's', s, seg.spiketrains[u] is unit.spiketrains[ s], sptr.to_neo().size