コード例 #1
0
def test_auto_lostfront_chunksize():
    """Compare online vs offline preprocessing with ``lostfront_chunksize=None``.

    ``None`` presumably lets the preprocessor pick the lost-front size itself
    (per the test name) -- TODO confirm against the preprocessor.  The numpy
    engine is run chunk by chunk through ``run_online`` and the maximum
    residual, normalized by the mean absolute offline signal, is printed.
    No tolerance is asserted.
    """
    # Load once and truncate to a whole number of chunks so both paths see
    # exactly the same samples.
    sigs, sample_rate = get_dataset(name='olfactory_bulb')
    chunksize = 1024
    nloop = sigs.shape[0] // chunksize
    sigs = sigs[:chunksize * nloop]

    params = {
        'common_ref_removal': False,
        'highpass_freq': 300.,
        'lowpass_freq': 4000.,
        'smooth_size': 0,
        'output_dtype': 'float32',
        'normalize': True,
        'lostfront_chunksize': None,
    }

    offline_sig = offline_signal_preprocessor(sigs, sample_rate, **params)
    online_sig = run_online('numpy', sigs, sample_rate, chunksize, **params)

    # Compare only the interior: skip the first chunk, and stop where the
    # (possibly shorter) online output ends.
    min_size = online_sig.shape[0]
    offline64 = offline_sig[chunksize:min_size].astype('float64')
    online64 = online_sig[chunksize:min_size].astype('float64')

    residual = np.abs(online64 - offline64)
    residual_normed = residual / np.mean(np.abs(offline64))
    print('   max residual_normed', np.max(residual_normed))
コード例 #2
0
def test_get_dataset():
    """Check the shape and sample rate of the locust and olfactory_bulb datasets."""
    # (dataset name, expected (n_samples, n_channels), expected sample rate)
    expectations = (
        ('locust', (431548, 4), 15000.0),
        ('olfactory_bulb', (150000, 14), 10000.0),
    )
    for name, expected_shape, expected_rate in expectations:
        data, sample_rate = get_dataset(name=name)
        assert data.shape == expected_shape
        assert sample_rate == expected_rate
コード例 #3
0
def test_get_dataset():
    """Check that every bundled dataset loads and looks sane.

    Datasets with known dimensions are checked exactly.  For the others only
    a basic sanity check is done, so a broken loader still fails the test
    instead of passing just because no exception was raised.
    """
    data, sample_rate = get_dataset(name='locust')
    assert data.shape == (431548, 4)
    assert sample_rate == 15000.0

    data, sample_rate = get_dataset(name='olfactory_bulb')
    assert data.shape == (150000, 14)
    assert sample_rate == 10000.0

    # Exact sizes are not pinned here.  Check only that each dataset is a
    # non-empty 2-D (sample, channel) array with a positive sample rate,
    # like the two datasets above.
    for name in ('purkinje', 'striatum_rat'):
        data, sample_rate = get_dataset(name=name)
        assert data.ndim == 2, name
        assert data.shape[0] > 0 and data.shape[1] > 0, name
        assert sample_rate > 0, name
コード例 #4
0
def run_online(engine, sigs, sample_rate, chunksize, **params):
    """Run a signal-preprocessor engine chunk by chunk over ``sigs``.

    Parameters
    ----------
    engine : str
        Key into ``signalpreprocessor_engines`` (e.g. 'numpy', 'opencl').
    sigs : ndarray
        Raw signals indexed as ``sigs[sample, channel]``.
    sample_rate : float
        Sampling rate passed to the offline preprocessor and the engine.
    chunksize : int
        Number of samples fed to the engine per ``process_data`` call.
        Trailing samples that do not fill a whole chunk are not processed.
    **params
        Preprocessing parameters.  They are forwarded to
        ``offline_signal_preprocessor`` (to estimate noise) and to the
        engine's ``change_params``.

    Returns
    -------
    ndarray
        Concatenation of every non-None chunk returned by ``process_data``.

    Raises
    ------
    ValueError
        If the engine returned no output for any chunk.
    """
    nb_channel = sigs.shape[1]

    # Noise statistics must come from the same signals that are streamed,
    # otherwise the online normalization cannot match the offline one.
    params2 = dict(params)
    params2['normalize'] = False
    sigs_for_noise = offline_signal_preprocessor(sigs, sample_rate, **params2)
    medians = np.median(sigs_for_noise, axis=0)
    # 1.4826 is the usual factor that turns a MAD into a sigma estimate for
    # Gaussian noise.
    mads = np.median(np.abs(sigs_for_noise - medians), axis=0) * 1.4826
    # ``params`` is this call's own kwargs dict, so the caller is unaffected.
    params['signals_medians'] = medians
    params['signals_mads'] = mads

    SignalPreprocessorClass = signalpreprocessor_engines[engine]
    signalpreprocessor = SignalPreprocessorClass(sample_rate, nb_channel,
                                                 chunksize, sigs.dtype)
    signalpreprocessor.change_params(**params)

    nloop = sigs.shape[0] // chunksize

    all_online_sigs = []
    t1 = time.perf_counter()
    for i in range(nloop):
        pos = (i + 1) * chunksize
        chunk = sigs[pos - chunksize:pos, :]
        _, preprocessed_chunk = signalpreprocessor.process_data(pos, chunk)
        # process_data may return None for a chunk; such chunks are skipped.
        if preprocessed_chunk is not None:
            all_online_sigs.append(preprocessed_chunk)
    if not all_online_sigs:
        raise ValueError('engine {!r} produced no output'.format(engine))
    online_sig = np.concatenate(all_online_sigs)
    t2 = time.perf_counter()
    print(engine, 'process time', t2 - t1)

    return online_sig
コード例 #5
0
def explore_lostfront_chunksize():
    """Print how the online/offline mismatch varies with lostfront_chunksize.

    The candidate sizes are 1 to 4 multiples of ``int(sample_rate /
    highpass_freq)``, i.e. one high-pass cutoff period in samples.  For each
    size the numpy engine is run online.  The maximum residual against the
    offline output, normalized by the mean absolute offline signal, is then
    printed.  Nothing is asserted; this is an exploration helper.
    """
    sigs, sample_rate = get_dataset(name='olfactory_bulb')
    chunksize = 1024
    usable = (sigs.shape[0] // chunksize) * chunksize
    sigs = sigs[:usable]

    params = {
        'common_ref_removal': False,
        'highpass_freq': 300.,
        'lowpass_freq': 4000.,
        'smooth_size': 0,
        'output_dtype': 'float32',
        'normalize': True,
    }

    offline_sig = offline_signal_preprocessor(sigs, sample_rate, **params)

    period = int(sample_rate / params['highpass_freq'])
    candidates = [period * k for k in range(1, 5)]

    online_sigs = {}
    for size in candidates:
        print('lostfront_chunksize', size)
        run_params = dict(params, lostfront_chunksize=size)
        online_sigs[size] = run_online('numpy', sigs, sample_rate, chunksize,
                                       **run_params)

    # Compare only the span every run produced, minus the first chunk.
    end = min(online_sigs[size].shape[0] for size in candidates)
    reference = offline_sig[chunksize:end].astype('float64')
    scale = np.mean(np.abs(reference))

    for size in candidates:
        print('lostfront_chunksize', size)
        trimmed = online_sigs[size][chunksize:end].astype('float64')
        residual_normed = np.abs(trimmed - reference) / scale
        print('   max residual_normed', np.max(residual_normed))
コード例 #6
0
def get_normed_sigs(chunksize=None):
    """Load olfactory_bulb signals together with their preprocessed version.

    Parameters
    ----------
    chunksize : int or None
        If given, the raw signals are truncated to a whole number of chunks,
        and ``backward_chunksize`` is set to ``chunksize + chunksize // 4``.
        If None (the default, which previously crashed), no truncation is
        done and ``backward_chunksize`` is left to the preprocessor's default.

    Returns
    -------
    tuple
        ``(sigs, sample_rate, normed_sigs, geometry)``.  ``geometry`` is a
        ``(nb_channel, 2)`` array that places the channels 50 um apart along
        the first column.
    """
    sigs, sample_rate = get_dataset(name='olfactory_bulb')
    #~ sigs = np.tile(sigs, (1, 20)) #for testing large channels num

    preprocess_params = dict(highpass_freq=300.,
                             common_ref_removal=True,
                             output_dtype='float32')
    if chunksize is not None:
        # Drop trailing samples that would not fill a whole chunk.
        sigs = sigs[:sigs.shape[0] - sigs.shape[0] % chunksize, :]
        preprocess_params['backward_chunksize'] = chunksize + chunksize // 4

    nb_channel = sigs.shape[1]

    geometry = np.zeros((nb_channel, 2))
    geometry[:, 0] = np.arange(nb_channel) * 50  # um spacing

    normed_sigs = offline_signal_preprocessor(sigs, sample_rate,
                                              **preprocess_params)

    return sigs, sample_rate, normed_sigs, geometry
コード例 #7
0
def test_compare_offline_online_engines():
    """Check that online peak-detector engines match offline peak detection.

    The olfactory_bulb signals are preprocessed offline.  They are then fed
    chunk by chunk to each peak-detector engine: 'numpy', plus 'opencl' when
    ``HAVE_PYOPENCL`` is true.  Peaks within one chunk of either border are
    ignored, and the remaining peak indices must equal the offline ones
    exactly.  This is checked for both '-' and '+' peak signs.
    """
    engines = ['numpy', 'opencl'] if HAVE_PYOPENCL else ['numpy']

    sigs, sample_rate = get_dataset(name='olfactory_bulb')
    nb_channel = sigs.shape[1]
    print('nb_channel', nb_channel)

    # Parameters shared by the offline and online detectors.
    chunksize = 1024
    relative_threshold = 8
    peak_span = 0.0009

    # Truncate to a whole number of chunks so both paths see the same data.
    nloop = sigs.shape[0] // chunksize
    sigs = sigs[:chunksize * nloop]
    print('sig duration', sigs.shape[0] / sample_rate)

    preprocess_params = dict(highpass_freq=300.,
                             common_ref_removal=True,
                             backward_chunksize=chunksize + chunksize // 4,
                             output_dtype='float32')
    normed_sigs = offline_signal_preprocessor(sigs, sample_rate,
                                              **preprocess_params)

    for peak_sign in ['-', '+']:
        print()
        print('peak_sign', peak_sign)
        # For '+' the detectors run on the mirrored signal.
        test_sigs = normed_sigs if peak_sign == '-' else -normed_sigs

        t1 = time.perf_counter()
        offline_peaks, _ = offline_peak_detect(
            test_sigs, sample_rate, peak_sign=peak_sign,
            relative_threshold=relative_threshold, peak_span=peak_span)
        t2 = time.perf_counter()
        print('offline', 'process time', t2 - t1)

        online_peaks = {}
        for engine in engines:
            print(engine)
            EngineClass = peakdetector_engines[engine]
            peakdetector_engine = EngineClass(sample_rate, nb_channel,
                                              chunksize, 'float32')
            peakdetector_engine.change_params(
                peak_sign=peak_sign,
                relative_threshold=relative_threshold,
                peak_span=peak_span)

            all_online_peaks = []
            t1 = time.perf_counter()
            for i in range(nloop):
                pos = (i + 1) * chunksize
                chunk = test_sigs[pos - chunksize:pos, :]
                _, chunk_peaks = peakdetector_engine.process_data(pos, chunk)
                if chunk_peaks is not None:
                    all_online_peaks.append(chunk_peaks)
            assert all_online_peaks, '{} returned no peaks'.format(engine)
            online_peaks[engine] = np.concatenate(all_online_peaks)
            t2 = time.perf_counter()
            print(engine, 'process time', t2 - t1)

        # Only the interior is compared: drop peaks within one chunk of
        # either border.
        lo, hi = chunksize, test_sigs.shape[0] - chunksize
        offline_peaks = offline_peaks[(offline_peaks > lo) & (offline_peaks < hi)]
        for engine in engines:
            peaks = online_peaks[engine]
            peaks = peaks[(peaks > lo) & (peaks < hi)]
            assert offline_peaks.size == peaks.size, \
                '{}: {} peaks instead of {} (offline)'.format(
                    engine, peaks.size, offline_peaks.size)
            assert np.array_equal(offline_peaks, peaks)
コード例 #8
0
def test_compare_offline_online_engines():
    """Check that online peak-detector engines match offline peak detection.

    The olfactory_bulb signals are preprocessed offline.  They are then fed
    chunk by chunk to each peak-detector engine: 'numpy', plus 'opencl' when
    the module-level ``HAVE_PYOPENCL`` is true.  The flag is deliberately not
    forced to True here, so the test still runs without OpenCL.  Peaks within
    one chunk of either border are ignored, and the rest must equal the
    offline ones exactly, for both '-' and '+' peak signs.
    """
    engines = ['numpy', 'opencl'] if HAVE_PYOPENCL else ['numpy']

    sigs, sample_rate = get_dataset(name='olfactory_bulb')
    nb_channel = sigs.shape[1]
    print('nb_channel', nb_channel)

    # Parameters shared by the offline and online detectors.
    chunksize = 1024
    relative_threshold = 8
    peak_span = 0.0009

    # Truncate to a whole number of chunks so both paths see the same data.
    nloop = sigs.shape[0] // chunksize
    sigs = sigs[:chunksize * nloop]
    print('sig duration', sigs.shape[0] / sample_rate)

    preprocess_params = dict(highpass_freq=300.,
                             common_ref_removal=True,
                             backward_chunksize=chunksize + chunksize // 4,
                             output_dtype='float32')
    normed_sigs = offline_signal_preprocessor(sigs, sample_rate,
                                              **preprocess_params)

    for peak_sign in ['-', '+']:
        print()
        print('peak_sign', peak_sign)
        # For '+' the detectors run on the mirrored signal.
        test_sigs = normed_sigs if peak_sign == '-' else -normed_sigs

        t1 = time.perf_counter()
        offline_peaks, _ = offline_peak_detect(
            test_sigs, sample_rate, peak_sign=peak_sign,
            relative_threshold=relative_threshold, peak_span=peak_span)
        t2 = time.perf_counter()
        print('offline', 'process time', t2 - t1)

        online_peaks = {}
        for engine in engines:
            print(engine)
            EngineClass = peakdetector_engines[engine]
            peakdetector_engine = EngineClass(sample_rate, nb_channel,
                                              chunksize, 'float32')
            peakdetector_engine.change_params(
                peak_sign=peak_sign,
                relative_threshold=relative_threshold,
                peak_span=peak_span)

            all_online_peaks = []
            t1 = time.perf_counter()
            for i in range(nloop):
                pos = (i + 1) * chunksize
                chunk = test_sigs[pos - chunksize:pos, :]
                _, chunk_peaks = peakdetector_engine.process_data(pos, chunk)
                if chunk_peaks is not None:
                    all_online_peaks.append(chunk_peaks)
            assert all_online_peaks, '{} returned no peaks'.format(engine)
            online_peaks[engine] = np.concatenate(all_online_peaks)
            t2 = time.perf_counter()
            print(engine, 'process time', t2 - t1)

        # Only the interior is compared: drop peaks within one chunk of
        # either border.
        lo, hi = chunksize, test_sigs.shape[0] - chunksize
        offline_peaks = offline_peaks[(offline_peaks > lo) & (offline_peaks < hi)]
        for engine in engines:
            peaks = online_peaks[engine]
            peaks = peaks[(peaks > lo) & (peaks < hi)]
            assert offline_peaks.size == peaks.size, \
                '{}: {} peaks instead of {} (offline)'.format(
                    engine, peaks.size, offline_peaks.size)
            assert np.array_equal(offline_peaks, peaks)
コード例 #9
0
def test_compare_offline_online_engines():
    """Check online signal-preprocessor engines against offline preprocessing.

    Settings: highpass_freq=500, lowpass_freq=4000, normalization on and
    lostfront_chunksize=150.  Each engine is run through ``run_online``.
    The first chunk is dropped, as is anything beyond the shortest online
    output.  The max absolute difference, divided by the mean absolute
    offline signal, must then stay below 5%.
    """
    engines = ['numpy', 'opencl'] if HAVE_PYOPENCL else ['numpy']

    sigs, sample_rate = get_dataset(name='olfactory_bulb')
    nb_channel = sigs.shape[1]
    print('nb_channel', nb_channel)

    # Truncate to a whole number of chunks so both paths see the same data.
    chunksize = 1024
    nloop = sigs.shape[0] // chunksize
    sigs = sigs[:chunksize * nloop]
    print('sig duration', sigs.shape[0] / sample_rate)

    params = {
        'common_ref_removal': False,
        'highpass_freq': 500.,
        'lowpass_freq': 4000.,
        'smooth_size': 0,
        'output_dtype': 'float32',
        'normalize': True,
        'lostfront_chunksize': 150,
    }

    t1 = time.perf_counter()
    offline_sig = offline_signal_preprocessor(sigs, sample_rate, **params)
    t2 = time.perf_counter()
    print('offline', 'process time', t2 - t1)

    online_sigs = {}
    for engine in engines:
        online_sigs[engine] = run_online(engine, sigs, sample_rate, chunksize,
                                         **params)

    # Compare only the span every engine produced, minus the first chunk.
    min_size = min(online_sigs[engine].shape[0] for engine in engines)
    offline64 = offline_sig[chunksize:min_size].astype('float64')
    # The normaliser does not depend on the engine, so compute it once.
    scale = np.mean(np.abs(offline64))

    for engine in engines:
        print(engine)
        online64 = online_sigs[engine][chunksize:min_size].astype('float64')
        residual = np.abs(online64 - offline64)
        residual_normed = residual / scale
        print('max residual', np.max(residual))
        print('max residual_normed', np.max(residual_normed))
        assert np.max(residual_normed) < 0.05, \
            'online differs from offline by more than 5%'
コード例 #10
0
def test_compare_offline_online_engines():
    """Check online signal-preprocessor engines against offline preprocessing.

    Settings: common reference removal, highpass_freq=300 and normalization.
    The online engines are given noise medians/MADs precomputed offline.
    Engines are 'numpy', plus 'opencl' when the module-level
    ``HAVE_PYOPENCL`` is true; the flag is not forced here.  After dropping
    the first chunk and anything beyond the shortest online output, the
    normalized residual must stay below 7e-5.
    """
    engines = ['numpy', 'opencl'] if HAVE_PYOPENCL else ['numpy']

    sigs, sample_rate = get_dataset(name='olfactory_bulb')
    nb_channel = sigs.shape[1]
    print('nb_channel', nb_channel)

    # Truncate to a whole number of chunks so both paths see the same data.
    chunksize = 1024
    nloop = sigs.shape[0] // chunksize
    sigs = sigs[:chunksize * nloop]
    print('sig duration', sigs.shape[0] / sample_rate)

    params = {
        'common_ref_removal': True,
        'highpass_freq': 300.,
        'output_dtype': 'float32',
        'normalize': True,
        'backward_chunksize': chunksize + chunksize // 4,
    }

    t1 = time.perf_counter()
    offline_sig = offline_signal_preprocessor(sigs, sample_rate, **params)
    t2 = time.perf_counter()
    print('offline', 'process time', t2 - t1)

    # Noise statistics for the online engines, taken from the un-normalized
    # offline output.
    params2 = dict(params, normalize=False)
    sigs_for_noise = offline_signal_preprocessor(sigs, sample_rate, **params2)
    medians = np.median(sigs_for_noise, axis=0)
    mads = np.median(np.abs(sigs_for_noise - medians), axis=0) * 1.4826
    params['signals_medians'] = medians
    params['signals_mads'] = mads

    online_sigs = {}
    for engine in engines:
        print(engine)
        SignalPreprocessorClass = signalpreprocessor_engines[engine]
        signalpreprocessor = SignalPreprocessorClass(sample_rate, nb_channel,
                                                     chunksize, sigs.dtype)
        signalpreprocessor.change_params(**params)

        all_online_sigs = []
        t1 = time.perf_counter()
        for i in range(nloop):
            pos = (i + 1) * chunksize
            chunk = sigs[pos - chunksize:pos, :]
            _, preprocessed_chunk = signalpreprocessor.process_data(pos, chunk)
            if preprocessed_chunk is not None:
                all_online_sigs.append(preprocessed_chunk)
        assert all_online_sigs, '{} produced no output'.format(engine)
        online_sigs[engine] = np.concatenate(all_online_sigs)
        t2 = time.perf_counter()
        print(engine, 'process time', t2 - t1)

    # Compare only the span every engine produced, minus the first chunk.
    min_size = min(online_sigs[engine].shape[0] for engine in engines)
    offline64 = offline_sig[chunksize:min_size].astype('float64')
    scale = np.mean(np.abs(offline64))

    for engine in engines:
        online64 = online_sigs[engine][chunksize:min_size].astype('float64')
        residual = np.abs((online64 - offline64) / scale)
        print(np.max(residual))
        assert np.max(residual) < 7e-5, 'online differs from offline'
コード例 #11
0
def test_compare_offline_online_engines():
    """Check the online preprocess -> peak detect -> waveform extraction chain.

    The default dataset is first preprocessed, peak-detected and cut into
    waveforms offline.  The same chain is then run chunk by chunk with the
    numpy engines and ``OnlineWaveformExtractor``.  Peaks within one chunk of
    either border are ignored.  Peak positions must match exactly, and the
    waveforms must agree within 5e-5 of the mean absolute offline waveform
    value.

    No plots are shown, so the test never blocks; call
    ``_plot_waveform_comparison`` by hand when debugging a mismatch.
    """
    sigs, sample_rate = get_dataset()
    nb_channel = sigs.shape[1]
    print('nb_channel', nb_channel)

    # Truncate to a whole number of chunks so both paths see the same data.
    chunksize = 1024
    nloop = sigs.shape[0] // chunksize
    sigs = sigs[:chunksize * nloop]

    preprocess_params = dict(highpass_freq=300.,
                             common_ref_removal=True,
                             backward_chunksize=chunksize + chunksize // 4,
                             output_dtype='float32')
    peak_params = dict(peak_sign='-', relative_threshold=8, peak_span=0.0009)
    waveforms_params = dict(n_left=-20, n_right=30)
    # Use the same window for the offline cut as for the online extractor.
    n_left = waveforms_params['n_left']
    n_right = waveforms_params['n_right']

    t1 = time.perf_counter()
    offline_sig = offline_signal_preprocessor(sigs, sample_rate,
                                              **preprocess_params)
    offline_peaks, _ = offline_peak_detect(offline_sig, sample_rate,
                                           **peak_params)
    keep = (offline_peaks > chunksize) & (offline_peaks < sigs.shape[0] - chunksize)
    offline_peaks = offline_peaks[keep]
    offline_waveforms = cut_full(offline_sig, offline_peaks + n_left,
                                 n_right - n_left)
    print(offline_waveforms.shape)
    t2 = time.perf_counter()
    print('offline', 'process time', t2 - t1)

    # Noise statistics for online normalization, taken from the
    # un-normalized offline output.
    params2 = dict(preprocess_params, normalize=False)
    sigs_for_noise = offline_signal_preprocessor(sigs, sample_rate, **params2)
    medians = np.median(sigs_for_noise, axis=0)
    mads = np.median(np.abs(sigs_for_noise - medians), axis=0) * 1.4826
    preprocess_params['signals_medians'] = medians
    preprocess_params['signals_mads'] = mads

    signalpreprocessor = SignalPreprocessor_Numpy(sample_rate, nb_channel,
                                                  chunksize, sigs.dtype)
    signalpreprocessor.change_params(**preprocess_params)

    peakdetector = PeakDetectorEngine_Numpy(sample_rate, nb_channel,
                                            chunksize, 'float32')
    peakdetector.change_params(**peak_params)

    waveformextractor = OnlineWaveformExtractor(nb_channel, chunksize)
    waveformextractor.change_params(**waveforms_params)

    all_online_peak = []
    all_online_waveforms = []

    t1 = time.perf_counter()
    for i in range(nloop):
        pos = (i + 1) * chunksize
        chunk = sigs[pos - chunksize:pos, :]

        pos2, preprocessed_chunk = signalpreprocessor.process_data(pos, chunk)
        if preprocessed_chunk is None:
            continue

        _, chunk_peaks = peakdetector.process_data(pos2, preprocessed_chunk)
        if chunk_peaks is None:
            continue

        new_peaks = waveformextractor.new_peaks(pos2, preprocessed_chunk,
                                                chunk_peaks)
        for peak_pos, chunk_waveforms in new_peaks:
            all_online_peak.append(peak_pos)
            all_online_waveforms.append(chunk_waveforms)
    t2 = time.perf_counter()
    print('online process time', t2 - t1)

    assert all_online_peak, 'online chain produced no waveforms'
    online_peaks = np.concatenate(all_online_peak, axis=0)
    online_waveforms = np.concatenate(all_online_waveforms, axis=0)

    keep = (online_peaks > chunksize) & (online_peaks < sigs.shape[0] - chunksize)
    online_peaks = online_peaks[keep]
    online_waveforms = online_waveforms[keep]

    assert np.array_equal(offline_peaks, online_peaks)

    offline64 = offline_waveforms.astype('float64')
    residual = np.abs((online_waveforms.astype('float64') - offline64)
                      / np.mean(np.abs(offline64)))
    print(np.max(residual))
    assert np.max(residual) < 5e-5, 'online differs from offline'


def _plot_waveform_comparison(offline_waveforms, online_waveforms, residual):
    """Plot offline vs online waveforms for manual inspection (debug aid).

    The first figure shows the waveform with the largest residual.  The
    second shows the per-sample median over all waveforms.  Offline is drawn
    in green and online in red dashed.  Ends with ``pyplot.show()``, which
    typically blocks with an interactive backend.
    """
    worst = np.argmax(np.max(residual.reshape(residual.shape[0], -1), axis=1))

    fig, ax = pyplot.subplots()
    ax.plot(offline_waveforms[worst, :, :].flatten(), color='g')
    ax.plot(online_waveforms[worst, :, :].flatten(), color='r', ls='--')

    fig, ax = pyplot.subplots()
    wf_offline = offline_waveforms.reshape(offline_waveforms.shape[0], -1)
    ax.plot(np.median(wf_offline, axis=0), color='g')
    wf_online = online_waveforms.reshape(online_waveforms.shape[0], -1)
    ax.plot(np.median(wf_online, axis=0), color='r', ls='--')

    pyplot.show()
コード例 #12
0
def test_compare_offline_online_engines():
    """Check online signal-preprocessor engines with high-pass filtering off.

    Settings: highpass_freq=None, lowpass_freq=4000, no smoothing,
    normalization on and lostfront_chunksize=128.  The online engines are
    given noise medians/MADs precomputed offline.  Engines are 'numpy', plus
    'opencl' when the module-level ``HAVE_PYOPENCL`` is true; the flag is not
    forced here.  After dropping the first chunk and anything beyond the
    shortest online output, the normalized residual must stay below 7e-5.
    """
    engines = ['numpy', 'opencl'] if HAVE_PYOPENCL else ['numpy']

    sigs, sample_rate = get_dataset(name='olfactory_bulb')
    nb_channel = sigs.shape[1]
    print('nb_channel', nb_channel)

    # Truncate to a whole number of chunks so both paths see the same data.
    chunksize = 1024
    nloop = sigs.shape[0] // chunksize
    sigs = sigs[:chunksize * nloop]
    print('sig duration', sigs.shape[0] / sample_rate)

    params = {
        'common_ref_removal': False,
        'highpass_freq': None,  # high-pass disabled in this variant
        'lowpass_freq': 4000.,
        'smooth_size': 0,
        'output_dtype': 'float32',
        'normalize': True,
        'lostfront_chunksize': 128,
    }

    t1 = time.perf_counter()
    offline_sig = offline_signal_preprocessor(sigs, sample_rate, **params)
    t2 = time.perf_counter()
    print('offline', 'process time', t2 - t1)

    # Noise statistics for the online engines, taken from the un-normalized
    # offline output.
    params2 = dict(params, normalize=False)
    sigs_for_noise = offline_signal_preprocessor(sigs, sample_rate, **params2)
    medians = np.median(sigs_for_noise, axis=0)
    mads = np.median(np.abs(sigs_for_noise - medians), axis=0) * 1.4826
    params['signals_medians'] = medians
    params['signals_mads'] = mads

    online_sigs = {}
    for engine in engines:
        print(engine)
        SignalPreprocessorClass = signalpreprocessor_engines[engine]
        signalpreprocessor = SignalPreprocessorClass(sample_rate, nb_channel,
                                                     chunksize, sigs.dtype)
        signalpreprocessor.change_params(**params)

        all_online_sigs = []
        t1 = time.perf_counter()
        for i in range(nloop):
            pos = (i + 1) * chunksize
            chunk = sigs[pos - chunksize:pos, :]
            _, preprocessed_chunk = signalpreprocessor.process_data(pos, chunk)
            if preprocessed_chunk is not None:
                all_online_sigs.append(preprocessed_chunk)
        assert all_online_sigs, '{} produced no output'.format(engine)
        online_sigs[engine] = np.concatenate(all_online_sigs)
        t2 = time.perf_counter()
        print(engine, 'process time', t2 - t1)

    # Compare only the span every engine produced, minus the first chunk.
    min_size = min(online_sigs[engine].shape[0] for engine in engines)
    offline64 = offline_sig[chunksize:min_size].astype('float64')
    scale = np.mean(np.abs(offline64))

    for engine in engines:
        print(engine)
        online64 = online_sigs[engine][chunksize:min_size].astype('float64')
        residual = np.abs((online64 - offline64) / scale)
        print(np.max(residual))
        assert np.max(residual) < 7e-5, 'online differs from offline'