Exemplo n.º 1
0
def test_peeler():
    """Smoke-test the pipeline from raw signals through the Peeler.

    Steps:
      1. Load segment 0 of the ``'datatest'`` dataset.
      2. Detect negative-going peaks.
      3. Extract waveforms and trim them to the "good" limits.
      4. Cluster the trimmed waveforms (PCA with 4 components, 8 clusters)
         and build a catalogue from them.
      5. Run the Peeler twice.
      6. Plot signals, predictions, residuals and spike trains on six
         stacked axes that share the x axis.

    NOTE(review): there are no assertions -- this only checks that the
    pipeline runs and leaves figures for visual inspection.  No
    ``pyplot.show()`` is called in this function.
    """
    dataio = DataIO(dirname='datatest')
    #~ dataio = DataIO(dirname='datatest_neo')

    sigs = dataio.get_signals(seg_num=0)

    # Peak detection.  The returned peak positions are not used here; the
    # detector object itself is reused (waveform extraction, normed_sigs).
    peakdetector = PeakDetector(sigs)
    peakdetector.detect_peaks(threshold=-4, peak_sign='-', n_span=5)

    # Waveforms.  limit_left/limit_right are also passed to the Peeler below.
    waveformextractor = WaveformExtractor(peakdetector, n_left=-30, n_right=50)
    limit_left, limit_right = waveformextractor.find_good_limits(mad_threshold=1.1)
    short_wf = waveformextractor.get_ajusted_waveforms()

    # Clustering.  project() is called for its effect on the clusterer; its
    # returned features are not needed here.
    clustering = Clustering(short_wf)
    clustering.project(method='pca', n_components=4)
    clustering.find_clusters(8, order_clusters=True)

    # Optional manual merge.  The catalogue is built only after this point so
    # that it always reflects the final cluster set.
    #~ clustering.merge_cluster(1, 2)
    catalogue = clustering.construct_catalogue()
    clustering.plot_catalogue(sameax=False)
    #~ clustering.plot_catalogue(sameax=True)

    # Peeler, run on the detector's normed_sigs (presumably the normalized
    # version of sigs -- TODO confirm in PeakDetector).
    signals = peakdetector.normed_sigs
    peeler = Peeler(signals, catalogue, limit_left, limit_right,
                    threshold=-5., peak_sign='-', n_span=5)

    # Two successive passes; presumably the second pass peels the first
    # pass's residuals -- TODO confirm against Peeler.peel.
    prediction0, residuals0 = peeler.peel()
    prediction1, residuals1 = peeler.peel()

    spiketrains = peeler.get_spiketrains()
    print(spiketrains)

    # Rows 0-4: one trace each, all on the same fixed y-range so they can be
    # compared directly (sharey is intentionally left off).
    # Row 5: the spike trains.
    _, axs = pyplot.subplots(nrows=6, sharex=True)
    traces = [signals, prediction0, residuals0, prediction1, residuals1]
    for ax, trace in zip(axs, traces):
        ax.plot(trace)
        ax.set_ylim(-25, 10)

    peeler.plot_spiketrains(ax=axs[5])