예제 #1
0
def test_plot_filter():
    """Smoke-test ``plot_filter`` on FIR, IIR (sos) and IIR (ba) designs.

    Each design is plotted with and without the ideal ``freq``/``gain``
    response, closing all figures after every plot.
    """
    l_freq, h_freq, sfreq = 2., 40., 1000.
    data = np.zeros(5000)
    freq = [0, 2, 40, 50, 500]
    gain = [0, 1, 1, 0, 0]
    ideal = (freq, gain)

    fir = create_filter(data, sfreq, l_freq, h_freq, fir_design='firwin2')
    for extra in ((), ideal):
        plot_filter(fir, sfreq, *extra)
        plt.close('all')

    iir = create_filter(data, sfreq, l_freq, h_freq, method='iir')
    for extra in ((), ideal):
        plot_filter(iir, sfreq, *extra)
        plt.close('all')

    iir_ba = create_filter(data, sfreq, l_freq, h_freq, method='iir',
                           iir_params=dict(output='ba'))
    plot_filter(iir_ba, sfreq, *ideal)
    plt.close('all')

    # Same FIR filter again, but on a linear frequency axis.
    plot_filter(fir, sfreq, *ideal, fscale='linear')
    plt.close('all')
예제 #2
0
def test_plot_filter():
    """Test filter plotting, including panel selection via ``plot``/``axes``."""
    l_freq, h_freq, sfreq = 2., 40., 1000.
    data = np.zeros(5000)
    freq = [0, 2, 40, 50, 500]
    gain = [0, 1, 1, 0, 0]
    h = create_filter(data, sfreq, l_freq, h_freq, fir_design='firwin2')
    plot_filter(h, sfreq)
    plt.close('all')
    plot_filter(h, sfreq, freq, gain)
    plt.close('all')
    iir = create_filter(data, sfreq, l_freq, h_freq, method='iir')
    plot_filter(iir, sfreq)
    plt.close('all')
    plot_filter(iir, sfreq, freq, gain)
    plt.close('all')
    iir_ba = create_filter(data, sfreq, l_freq, h_freq, method='iir',
                           iir_params=dict(output='ba'))
    plot_filter(iir_ba, sfreq, freq, gain)
    plt.close('all')
    # Default: all three panels.
    fig = plot_filter(h, sfreq, freq, gain, fscale='linear')
    assert len(fig.axes) == 3
    plt.close('all')
    fig = plot_filter(h, sfreq, freq, gain, fscale='linear',
                      plot=('time', 'delay'))
    assert len(fig.axes) == 2
    plt.close('all')
    fig = plot_filter(h, sfreq, freq, gain, fscale='linear',
                      plot=['magnitude', 'delay'])
    assert len(fig.axes) == 2
    plt.close('all')
    fig = plot_filter(h, sfreq, freq, gain, fscale='linear',
                      plot='magnitude')
    assert len(fig.axes) == 1
    plt.close('all')
    # A 1-tuple (note the trailing comma) must behave like the bare string.
    # ``('magnitude')`` is just the string again, so without the comma this
    # case duplicated the one above and never exercised the tuple path.
    fig = plot_filter(h, sfreq, freq, gain, fscale='linear',
                      plot=('magnitude',))
    assert len(fig.axes) == 1
    plt.close('all')
    with pytest.raises(ValueError, match='Invalid value for the .plot'):
        plot_filter(h, sfreq, freq, gain, plot='turtles')
    # User-supplied axes: the number of axes must match the requested panels.
    _, axes = plt.subplots(1)
    fig = plot_filter(h, sfreq, freq, gain, plot='magnitude', axes=axes)
    assert len(fig.axes) == 1
    plt.close('all')
    _, axes = plt.subplots(2)
    fig = plot_filter(h, sfreq, freq, gain, plot=('magnitude', 'delay'),
                      axes=axes)
    assert len(fig.axes) == 2
    plt.close('all')
    _, axes = plt.subplots(1)
    with pytest.raises(ValueError, match='Length of axes'):
        plot_filter(h, sfreq, freq, gain,
                    plot=('magnitude', 'delay'), axes=axes)
    # Don't leak this last figure into subsequent tests.
    plt.close('all')
def design_filter(filter_type, f_p, fir_design, trans_bandwidth,
                  filter_length, fir_window):
    """Design a one-sided filter with ``create_filter``.

    ``filter_type == 'highpass'`` puts ``f_p`` on the lower edge (and
    ``trans_bandwidth`` on the lower transition band); any other value puts
    it on the upper edge, i.e. a low-pass.

    NOTE(review): relies on a module-level ``sfreq`` that is not a parameter.
    """
    if filter_type == 'highpass':
        edge_kwargs = dict(l_freq=f_p, h_freq=None,
                           l_trans_bandwidth=trans_bandwidth)
    else:
        edge_kwargs = dict(l_freq=None, h_freq=f_p,
                           h_trans_bandwidth=trans_bandwidth)
    # A long dummy signal (100000 ones) serves as the design template.
    template = np.ones(100000)
    return create_filter(template, sfreq, filter_length=filter_length,
                         fir_design=fir_design, fir_window=fir_window,
                         **edge_kwargs)
예제 #4
0
def test_reporting_fir(phase, fir_window, btype):
    """Check the verbose report and the frequency response of FIR filters."""
    fs = 1000.
    is_bandpass = btype == 'bandpass'
    l_freq = 1. if is_bandpass else None
    with catch_logging() as log:
        taps = create_filter(None, fs, l_freq, 40, method='fir',
                             phase=phase, fir_window=fir_window, verbose=True)
    report = log.getvalue()

    # --- the log report -------------------------------------------------
    expected = ['FIR', btype, fir_window.capitalize(),
                'Filter length: %d samples' % (len(taps),),
                'passband ripple', 'stopband attenuation']
    if phase == 'minimum':
        expected.append(' causal ')
    else:
        expected.extend([' non-causal ', ' dB cutoff frequency: 45.00 Hz'])
        if is_bandpass:
            expected.append(' dB cutoff frequency: 0.50 Hz')
    for needle in expected:
        assert needle in report

    cutoff_label = {'zero': '-6 dB cutoff', 'zero-double': '-12 dB cutoff'}
    if phase in cutoff_label:
        assert cutoff_label[phase] in report
    else:
        # XXX Eventually we should figure out where the resulting point is,
        # since the minimum-phase process will change it. For now we don't
        # report it.
        assert phase == 'minimum'

    # --- the frequency response -----------------------------------------
    if phase == 'zero-double':
        taps = np.convolve(taps, taps)  # effectively what happens
    freqs, resp = freqz(taps, worN=10000)
    freqs *= fs / (2 * np.pi)
    mag = np.abs(resp)

    def nearest(target):
        """Index of the frequency bin closest to ``target`` Hz."""
        return np.argmin(np.abs(freqs - target))

    passes = [nearest(f) for f in (1, 20, 40)]
    stops = [nearest(50.)]
    mids = [nearest(45.)]  # transition band
    assert freqs[0] == 0.
    idx_0, idx_0p5 = 0, nearest(0.5)
    # The low-frequency bins are stop/transition for band-pass, pass otherwise.
    if is_bandpass:
        stops.append(idx_0)
        mids.append(idx_0p5)
    else:
        passes.extend([idx_0, idx_0p5])
    assert_allclose(mag[passes], 1., atol=0.01)
    stop_db = -20 if phase == 'minimum' else -50
    assert_allclose(mag[stops], 0., atol=10 ** (stop_db / 20.))
    if phase != 'minimum':  # haven't worked out the math for this yet
        mid_gain = 0.25 if phase == 'zero-double' else 0.5
        assert_allclose(mag[mids], mid_gain, atol=0.01)
예제 #5
0
def test_reporting_fir(phase, fir_window, btype):
    """Check the verbose report and the frequency response of FIR filters."""
    fs = 1000.
    is_bandpass = btype == 'bandpass'
    l_freq = 1. if is_bandpass else None
    with catch_logging() as log:
        taps = create_filter(None, fs, l_freq, 40, method='fir',
                             phase=phase, fir_window=fir_window, verbose=True)
    report = log.getvalue()

    # --- the log report -------------------------------------------------
    expected = ['FIR', btype, fir_window.capitalize(),
                'Filter length: %d samples' % (len(taps),),
                'passband ripple', 'stopband attenuation']
    if phase == 'minimum':
        expected.append(' causal ')
    else:
        expected.extend([' non-causal ', ' dB cutoff frequency: 45.00 Hz'])
        if is_bandpass:
            expected.append(' dB cutoff frequency: 0.50 Hz')
    for needle in expected:
        assert needle in report

    cutoff_label = {'zero': '-6 dB cutoff', 'zero-double': '-12 dB cutoff'}
    if phase in cutoff_label:
        assert cutoff_label[phase] in report
    else:
        # XXX Eventually we should figure out where the resulting point is,
        # since the minimum-phase process will change it. For now we don't
        # report it.
        assert phase == 'minimum'

    # --- the frequency response -----------------------------------------
    if phase == 'zero-double':
        taps = np.convolve(taps, taps)  # effectively what happens
    freqs, resp = freqz(taps, worN=10000)
    freqs *= fs / (2 * np.pi)
    mag = np.abs(resp)

    def nearest(target):
        """Index of the frequency bin closest to ``target`` Hz."""
        return np.argmin(np.abs(freqs - target))

    passes = [nearest(f) for f in (1, 20, 40)]
    stops = [nearest(50.)]
    mids = [nearest(45.)]  # transition band
    assert freqs[0] == 0.
    idx_0, idx_0p5 = 0, nearest(0.5)
    # The low-frequency bins are stop/transition for band-pass, pass otherwise.
    if is_bandpass:
        stops.append(idx_0)
        mids.append(idx_0p5)
    else:
        passes.extend([idx_0, idx_0p5])
    assert_allclose(mag[passes], 1., atol=0.01)
    stop_db = -20 if phase == 'minimum' else -50
    assert_allclose(mag[stops], 0., atol=10 ** (stop_db / 20.))
    if phase != 'minimum':  # haven't worked out the math for this yet
        mid_gain = 0.25 if phase == 'zero-double' else 0.5
        assert_allclose(mag[mids], mid_gain, atol=0.01)
예제 #6
0
def test_plot_filter():
    """Smoke-test ``plot_filter`` for FIR and IIR (sos and ba) filters."""
    l_freq, h_freq, sfreq = 2., 40., 1000.
    data = np.zeros(5000)
    freq = [0, 2, 40, 50, 500]
    gain = [0, 1, 1, 0, 0]

    def _plot_and_close(filt, *args, **kwargs):
        """Plot ``filt`` at ``sfreq`` and close every open figure."""
        plot_filter(filt, sfreq, *args, **kwargs)
        plt.close('all')

    fir = create_filter(data, sfreq, l_freq, h_freq, fir_design='firwin2')
    _plot_and_close(fir)
    _plot_and_close(fir, freq, gain)

    iir_sos = create_filter(data, sfreq, l_freq, h_freq, method='iir')
    _plot_and_close(iir_sos)
    _plot_and_close(iir_sos, freq, gain)

    iir_ba = create_filter(data, sfreq, l_freq, h_freq, method='iir',
                           iir_params=dict(output='ba'))
    _plot_and_close(iir_ba, freq, gain)

    _plot_and_close(fir, freq, gain, fscale='linear')
예제 #7
0
    def __init__(self, scale=500, filt=True):
        """Initialise the EEG canvas: shader program, labels and filter state.

        NOTE(review): besides its parameters this method reads names that
        must exist at module level: y, color, index, nrows, ncols, n, n_chan,
        ch_names, n_samples, sfreq, VERT_SHADER and FRAG_SHADER.
        """
        app.Canvas.__init__(self,
                            title='EEG - Use your wheel to zoom!',
                            keys='interactive')

        # Vertex attributes and uniforms uploaded to the GL program, in order.
        program = gloo.Program(VERT_SHADER, FRAG_SHADER)
        self.program = program
        gl_inputs = (
            ('a_position', y.reshape(-1, 1)),
            ('a_color', color),
            ('a_index', index),
            ('u_scale', (1., 1.)),
            ('u_size', (nrows, ncols)),
            ('u_n', n),
        )
        for key, value in gl_inputs:
            program[key] = value

        # One name label and one (initially empty) quality label per channel.
        self.font_size = 48.
        self.names = []
        self.quality = []
        for idx in range(n_chan):
            self.names.append(
                visuals.TextVisual(ch_names[idx], bold=True, color='white'))
            self.quality.append(
                visuals.TextVisual('', bold=True, color='white'))

        # 11-step RdYlGn palette, reversed.
        self.quality_colors = color_palette("RdYlGn", 11)[::-1]

        self.scale = scale
        self.n_samples = n_samples
        self.filt = filt

        self.data_f = np.zeros((n_samples, n_chan))
        self.data = np.zeros((n_samples, n_chan))

        # FIR band-pass (3-40 Hz): denominator is just [1.0]; the per-channel
        # initial state comes from lfilter_zi.
        self.af = [1.0]
        self.bf = create_filter(self.data_f.T, sfreq, 3, 40.,
                                method='fir', fir_design='firwin')
        self.filt_state = np.tile(lfilter_zi(self.bf, self.af),
                                  (n_chan, 1)).T

        self._timer = app.Timer('auto', connect=self.on_timer, start=True)
        gloo.set_viewport(0, 0, *self.physical_size)
        gloo.set_state(clear_color='black', blend=True,
                       blend_func=('src_alpha', 'one_minus_src_alpha'))

        self.show()
예제 #8
0
def make_filter(data, sfreq, l_freq, h_freq, method='fir', **mne_kwargs):
    '''
    Create a filter (with mne's ``create_filter``) to apply to raw data.

    Parameters: data -> a numpy array, an mne.EpochsArray, or None
                *sfreq -> sampling frequency of the data
                *l_freq -> lower pass-band edge in Hz (the high-pass
                           cutoff), or None for no high-pass
                *h_freq -> upper pass-band edge in Hz (the low-pass
                           cutoff), or None for no low-pass
                method -> the type of filter to build, 'fir' or 'iir'
                          (case-insensitive)
                **mne_kwargs -> additional keyword arguments forwarded to
                                mne's create_filter (e.g. fir_design)

                *must be a float or convertible to float (l_freq and
                 h_freq may also be None)

    Returns: numpy array (or dict) if method = 'fir' (or 'iir')

    Raises: TypeError if method is not a string or data has an unsupported
            type; ValueError if method is not 'fir'/'iir'. Both subclass
            Exception, so callers catching Exception keep working.
    '''

    sfreq = float(sfreq)
    # None is meaningful to create_filter (one-sided filter), so it must not
    # be coerced: float(None) would raise before mne ever sees the value.
    l_freq = None if l_freq is None else float(l_freq)
    h_freq = None if h_freq is None else float(h_freq)

    if not isinstance(method, str):
        raise TypeError('method must be a string')
    # Normalise once so the value handed to mne is the one we validated.
    method = method.lower()
    if method not in ('fir', 'iir'):
        raise ValueError("method must be either 'iir' or 'fir'")

    if isinstance(data, mne.EpochsArray):
        raw_data = data.get_data()
    elif isinstance(data, np.ndarray):
        raw_data = data
    elif data is None:
        raw_data = None
        print(
            'warning: data is None, no sanity checking is going to be performed...'
        )
    else:
        raise TypeError('data must be either a valid EpochsArray or ndarray')

    # Forward the extra keyword arguments the docstring promises.
    return mne_filter.create_filter(raw_data,
                                    sfreq,
                                    l_freq,
                                    h_freq,
                                    method=method,
                                    **mne_kwargs)
예제 #9
0
def filter(
    X: np.ndarray,
    sfreq: float,
    n_chans: int = 4,
    low: float = 3,
    high: float = 40,
    verbose: bool = True,
) -> np.ndarray:
    """Band-pass ``X`` along axis 0 with an FIR filter from ``create_filter``.

    Inspired by viewer_v2.py in muse-lsl. ``X`` is filtered with
    ``lfilter(..., axis=0)``, so it is expected to be (n_samples, n_chans);
    the initial state has ``n_chans`` columns and must match ``X``.
    """
    # A 10-second all-zero template, shaped (n_chans, n_samples), is what
    # the filter is designed against (presumably only its length matters).
    template = np.zeros((int(sfreq * 10), n_chans)).T
    numerator = create_filter(template, sfreq, low, high, method="fir",
                              verbose=verbose)
    denominator = [1.0]

    # Same steady-state initial conditions replicated for every channel.
    initial = np.tile(lfilter_zi(numerator, denominator), (n_chans, 1)).T
    filtered, _ = lfilter(numerator, denominator, X, axis=0, zi=initial)
    return filtered
예제 #10
0
def channel_filter(
    X: np.ndarray,
    n_chans: int,
    sfreq: int,
    device_backend: str,
    device_name: str,
    low: float = 3,
    high: float = 40,
    verbose: bool = False,
) -> np.ndarray:
    """Rescale (backend-dependent) and FIR band-pass ``X`` along axis 0.

    Inspired by viewer_v2.py in muse-lsl. ``"muselsl"`` data is used as is;
    ``"brainflow"`` data is divided by 1000 unless ``'muse'`` occurs in
    ``device_name``; any other backend raises ``ValueError``.
    """
    if device_backend == "brainflow":
        # hacky; muse brainflow devices do in fact seem to be in correct units
        if 'muse' not in device_name:
            X = X / 1000  # adjust scale of readings
    elif device_backend != "muselsl":
        raise ValueError(f"Unknown backend {device_backend}")

    # 10-second zero template, shaped (n_chans, n_samples), for the design.
    template = np.zeros((int(sfreq * 10), n_chans)).T
    numerator = create_filter(template, sfreq, low, high,
                              method="fir", verbose=verbose)
    denominator = [1.0]

    initial = np.tile(lfilter_zi(numerator, denominator), (n_chans, 1)).T
    filtered, _ = lfilter(numerator, denominator, X, axis=0, zi=initial)
    return filtered
예제 #11
0
def test_reporting_iir(ftype, btype, order, output):
    """Check the verbose report and the frequency response of IIR filters."""
    fs = 1000.
    is_bandpass = btype == 'bandpass'
    is_ellip = ftype == 'ellip'
    l_freq = 1. if is_bandpass else None
    rs = 80 if order != 1 else 20

    iir_params = dict(ftype=ftype, order=order, output=output)
    if is_ellip:
        iir_params.update(rp=3, rs=rs)  # ripple (dB), attenuation
        pass_tol = np.log10(iir_params['rp']) + 0.01
    else:
        pass_tol = 0.2

    with catch_logging() as log:
        filt = create_filter(None, fs, l_freq, 40., method='iir',
                             iir_params=iir_params, verbose=True)
    # A band-pass doubles the effective order.
    order_eff = order * (2 if is_bandpass else 1)
    if output == 'ba':
        assert len(filt['b']) == order_eff + 1
    report = log.getvalue().lower()

    # Expected (single-pass) magnitude at the cutoff, in dB.
    if is_ellip:
        dB_cutoff = -6.0
    elif order == 1 or ftype == 'butter':
        dB_cutoff = -6.02
    else:
        assert ftype == 'bessel'
        assert order == 4
        dB_cutoff = -15.16

    expected = ['IIR', 'zero-phase', 'two-pass forward and reverse',
                'non-causal', btype, ftype,
                'Filter order %d' % (order_eff * 2, ),
                'Cutoff ' if btype == 'lowpass' else 'Cutoffs ']
    if btype == 'lowpass':
        expected.append('%0.2f dB' % (dB_cutoff, ))
    for needle in expected:
        assert needle.lower() in report

    # --- the frequency response -----------------------------------------
    if output == 'ba':
        freqs, resp = freqz(filt['b'], filt['a'], worN=10000)
    else:
        freqs, resp = sosfreqz(filt['sos'], worN=10000)
    freqs *= fs / (2 * np.pi)
    mag = np.abs(resp)

    def nearest(target):
        """Index of the frequency bin closest to ``target`` Hz."""
        return np.argmin(np.abs(freqs - target))

    passes = [nearest(20)]
    decades = [nearest(400.)]  # one decade above the 40 Hz cutoff
    edges = [nearest(40.)]
    assert freqs[0] == 0.
    idx_0p1, idx_1 = nearest(0.1), nearest(1.)
    if is_bandpass:
        edges.append(idx_1)
        decades.append(idx_0p1)
    else:
        passes += [idx_0p1, idx_1]

    assert_allclose(mag[edges], 10 ** (dB_cutoff / 40.), atol=0.01)
    assert_allclose(mag[passes], 1., atol=pass_tol)
    if ftype == 'butter' and btype == 'lowpass':
        dB_decade = -27.74  # expected attenuation per order at one decade
        assert_allclose(mag[decades], 10 ** (dB_decade * order / 20.),
                        rtol=0.01)
    elif is_ellip:
        assert_array_less(mag[decades], 10 ** (-rs / 20))
예제 #12
0
def test_filters():
    """Test low-, band-, high-pass, and band-stop filters plus resampling."""
    rng = np.random.RandomState(0)
    sfreq = 100
    sig_len_secs = 15

    a = rng.randn(2, sig_len_secs * sfreq)

    # let's test our catchers
    for fl in ['blah', [0, 1], 1000.5, '10ss', '10']:
        pytest.raises((ValueError, TypeError), filter_data, a, sfreq, 4, 8,
                      None, fl, 1.0, 1.0, fir_design='firwin')
    with pytest.raises(TypeError, match='got <class'):
        filter_data(a, sfreq, 4, 8, None, 1000, 1.0, 1.0, n_jobs=0.5,
                    phase='zero', fir_design='firwin')
    with pytest.raises(ValueError, match='Invalid value'):
        filter_data(a, sfreq, 4, 8, None, 1000, 1.0, 1.0, n_jobs='blah',
                    phase='zero', fir_design='firwin')
    pytest.raises(ValueError, filter_data, a, sfreq, 4, 8, None, 100,
                  1., 1., fir_window='foo')
    pytest.raises(ValueError, filter_data, a, sfreq, 4, 8, None, 10,
                  1., 1., fir_design='firwin')  # too short
    # > Nyq/2
    pytest.raises(ValueError, filter_data, a, sfreq, 4, sfreq / 2., None,
                  100, 1.0, 1.0, fir_design='firwin')
    pytest.raises(ValueError, filter_data, a, sfreq, -1, None, None, 100,
                  1.0, 1.0, fir_design='firwin')
    # these should work
    create_filter(None, sfreq, None, None)
    create_filter(a, sfreq, None, None, fir_design='firwin')
    create_filter(a, sfreq, None, None, method='iir')

    # check our short-filter warning:
    with pytest.warns(RuntimeWarning, match='attenuation'):
        # Warning for low attenuation
        filter_data(a, sfreq, 1, 8, filter_length=256, fir_design='firwin2')
    with pytest.warns(RuntimeWarning, match='Increase filter_length'):
        # Warning for too short a filter
        filter_data(a, sfreq, 1, 8, filter_length='0.5s', fir_design='firwin2')

    # try new default and old default
    freqs = fftfreq(a.shape[-1], 1. / sfreq)
    A = np.abs(fft(a))
    kwargs = dict(fir_design='firwin')
    for fl in ['auto', '10s', '5000ms', 1024, 1023]:
        bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, **kwargs)
        bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0,
                         **kwargs)
        lp = filter_data(a, sfreq, None, 8, None, fl, 10, 1.0, n_jobs=2,
                         **kwargs)
        hp = filter_data(lp, sfreq, 4, None, None, fl, 1.0, 10, **kwargs)
        assert_allclose(hp, bp, rtol=1e-3, atol=2e-3)
        assert_allclose(bp + bs, a, rtol=1e-3, atol=1e-3)
        # Sanity check attenuation
        mask = (freqs > 5.5) & (freqs < 6.5)
        assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]),
                        1., atol=0.02)
        assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]),
                        0., atol=0.2)
        # now the minimum-phase versions
        bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0,
                         phase='minimum', **kwargs)
        bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0,
                         phase='minimum', **kwargs)
        assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]),
                        1., atol=0.11)
        assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]),
                        0., atol=0.3)

    # and since these are band-limited, downsampling/upsampling should be
    # close. bp is (n_channels, n_times) = (2, 1500), so the edge samples
    # must be trimmed along the last (time) axis: slicing bp[10:-10] hits the
    # 2-row channel axis, yields empty arrays, and makes the checks vacuous.
    n_resamp_ignore = 10
    inner = np.s_[:, n_resamp_ignore:-n_resamp_ignore]
    bp_up_dn = resample(resample(bp, 2, 1, n_jobs=2), 1, 2, n_jobs=2)
    assert_array_almost_equal(bp[inner], bp_up_dn[inner], 2)
    # note that on systems without CUDA, this line serves as a test for a
    # graceful fallback to n_jobs=None
    bp_up_dn = resample(resample(bp, 2, 1, n_jobs='cuda'), 1, 2, n_jobs='cuda')
    assert_array_almost_equal(bp[inner], bp_up_dn[inner], 2)
    # test to make sure our resampling matches scipy's
    bp_up_dn = sp_resample(sp_resample(bp, 2 * bp.shape[-1], axis=-1,
                                       window='boxcar'),
                           bp.shape[-1], window='boxcar', axis=-1)
    assert_array_almost_equal(bp[inner], bp_up_dn[inner], 2)

    # make sure we don't alias
    t = np.arange(sfreq * sig_len_secs) / float(sfreq)
    # make sinusoid close to the Nyquist frequency
    sig = np.sin(2 * np.pi * sfreq / 2.2 * t)
    # signal should disappear with 2x downsampling (sig is 1D, so a plain
    # slice trims its time axis)
    sig_gone = resample(sig, 1, 2)[n_resamp_ignore:-n_resamp_ignore]
    assert_array_almost_equal(np.zeros_like(sig_gone), sig_gone, 2)

    # let's construct some filters
    iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='ba')
    iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
    # this should be a third order filter
    assert iir_params['a'].size - 1 == 3
    assert iir_params['b'].size - 1 == 3
    iir_params = dict(ftype='butter', order=4, output='ba')
    iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
    assert iir_params['a'].size - 1 == 4
    assert iir_params['b'].size - 1 == 4
    iir_params = dict(ftype='cheby1', gpass=1, gstop=20)
    iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
    # this should be a third order filter, which requires 2 SOS ((2, 6))
    assert iir_params['sos'].shape == (2, 6)
    iir_params = dict(ftype='butter', order=4, output='sos')
    iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
    assert iir_params['sos'].shape == (2, 6)

    # check that picks work for 3d array with one channel and picks=[0]
    a = rng.randn(5 * sfreq, 5 * sfreq)
    b = a[:, None, :]

    a_filt = filter_data(a, sfreq, 4, 8, None, 400, 2.0, 2.0,
                         fir_design='firwin')
    b_filt = filter_data(b, sfreq, 4, 8, [0], 400, 2.0, 2.0,
                         fir_design='firwin')

    assert_array_equal(a_filt[:, None, :], b_filt)

    # check for n-dimensional case
    a = rng.randn(2, 2, 2, 2)
    with pytest.warns(RuntimeWarning, match='longer'):
        pytest.raises(ValueError, filter_data, a, sfreq, 4, 8,
                      np.array([0, 1]), 100, 1.0, 1.0)

    # check corner case (#4693)
    want_length = int(round(_length_factors['hamming'] * 1000. / 0.5))
    want_length += (want_length % 2 == 0)
    assert want_length == 6601
    h = create_filter(np.empty(10000), 1000., l_freq=None, h_freq=55.,
                      h_trans_bandwidth=0.5, method='fir',
                      phase='zero-double', fir_design='firwin', verbose=True)
    assert len(h) == 6601
    h = create_filter(np.empty(10000), 1000., l_freq=None, h_freq=55.,
                      h_trans_bandwidth=0.5, method='fir', phase='zero',
                      fir_design='firwin', filter_length='7s', verbose=True)
    assert len(h) == 7001
    h = create_filter(np.empty(10000), 1000., l_freq=None, h_freq=55.,
                      h_trans_bandwidth=0.5, method='fir',
                      phase='zero-double', fir_design='firwin',
                      filter_length='7s', verbose=True)
    assert len(h) == 8193  # next power of two
예제 #13
0
    def __init__(self, lsl_inlet, scale=500, filt=True):
        """Create a scrolling EEG canvas fed by an LSL inlet.

        Parameters
        ----------
        lsl_inlet : LSL stream inlet
            Its ``info()`` supplies the nominal sampling rate, the channel
            count and the per-channel ``label`` entries shown on screen.
        scale : float
            Stored as ``self.scale``; presumably the initial display gain
            (TODO confirm against the drawing / wheel handlers).
        filt : bool
            Stored as ``self.filt``; presumably toggles band-pass filtering
            of the displayed data (TODO confirm in on_timer).
        """
        app.Canvas.__init__(self,
                            title='EEG - Use your wheel to zoom!',
                            keys='interactive')

        self.inlet = lsl_inlet
        info = self.inlet.info()
        description = info.desc()

        window = 10  # seconds of signal kept on screen
        self.sfreq = info.nominal_srate()
        n_samples = int(self.sfreq * window)
        self.n_chans = info.channel_count()

        # Read the first channel label, then walk its siblings. Only
        # n_chans - 1 siblings remain after the first one; looping n_chans
        # times stepped one element past the last channel.
        ch = description.child('channels').first_child()
        ch_names = [ch.child_value('label')]
        for _ in range(self.n_chans - 1):
            ch = ch.next_sibling()
            ch_names.append(ch.child_value('label'))

        # Number of cols and rows in the table.
        n_rows = self.n_chans
        n_cols = 1

        # Number of signals.
        m = n_rows * n_cols

        # Number of samples per signal.
        n = n_samples

        # Various signal amplitudes.
        amplitudes = np.zeros((m, n)).astype(np.float32)
        # Generate the signals as a (m, n) array.
        y = amplitudes

        color = color_palette("RdBu_r", n_rows)

        color = np.repeat(color, n, axis=0).astype(np.float32)
        # Signal 2D index of each vertex (row and col) and x-index (sample index
        # within each signal).
        index = np.c_[np.repeat(np.repeat(np.arange(n_cols), n_rows), n),
                      np.repeat(np.tile(np.arange(n_rows), n_cols), n),
                      np.tile(np.arange(n), m)].astype(np.float32)

        self.program = gloo.Program(VERT_SHADER, FRAG_SHADER)
        self.program['a_position'] = y.reshape(-1, 1)
        self.program['a_color'] = color
        self.program['a_index'] = index
        self.program['u_scale'] = (1., 1.)
        self.program['u_size'] = (n_rows, n_cols)
        self.program['u_n'] = n

        # text: one name label and one (initially empty) quality label per channel
        self.font_size = 48.
        self.names = []
        self.quality = []
        for ii in range(self.n_chans):
            text = visuals.TextVisual(ch_names[ii], bold=True, color='white')
            self.names.append(text)
            text = visuals.TextVisual('', bold=True, color='white')
            self.quality.append(text)

        self.quality_colors = color_palette("RdYlGn", 11)[::-1]

        self.scale = scale
        self.n_samples = n_samples
        self.filt = filt
        self.af = [1.0]  # FIR filter: denominator is 1

        self.data_f = np.zeros((n_samples, self.n_chans))
        self.data = np.zeros((n_samples, self.n_chans))

        # 3-40 Hz FIR band-pass designed against a window-sized template.
        self.bf = create_filter(self.data_f.T,
                                self.sfreq,
                                3,
                                40.,
                                method='fir')

        zi = lfilter_zi(self.bf, self.af)
        self.filt_state = np.tile(zi, (self.n_chans, 1)).transpose()

        self._timer = app.Timer('auto', connect=self.on_timer, start=True)
        gloo.set_viewport(0, 0, *self.physical_size)
        gloo.set_state(clear_color='black',
                       blend=True,
                       blend_func=('src_alpha', 'one_minus_src_alpha'))

        self.show()
예제 #14
0
    def __init__(self, lsl_inlet, name = "Unknown", scale=500, filt=True):
        """Create an EEG canvas fed by an LSL inlet, plus a worker thread.

        Parameters
        ----------
        lsl_inlet : LSL stream inlet
            Its ``info()`` supplies the nominal sampling rate, the channel
            count and the per-channel ``label`` entries.
        name : str
            Used in the window title and in the recording CSV filename.
        scale : float
            Stored as ``self.scale``; presumably the initial display gain.
        filt : bool
            Stored as ``self.filt``; presumably toggles filtering.
        """
        app.Canvas.__init__(self, title='EEG - ' + name,
                            keys='interactive')
        self.status = False
        self.name = name
        self.windowData = []
        self.previousWindowData = []
        self.alphaCounter = []
        self.freq = 256
        self.isClosed = False
        self.counter = 0
        self.temp_cal_alplha = []
        self.calibrate_alpha = 0
        self.filename = os.path.join(os.getcwd(), 'recording_' + self.name + '_' + strftime("%Y-%m-%d-%H.%M.%S", gmtime()) + '.csv')
        self.inlet = lsl_inlet
        self.isAction = False
        info = self.inlet.info()
        description = info.desc()
        # Reference signal: sum of unit sinusoids at 8..12 cycles per 256
        # samples (8-12 Hz if sampled at 256 Hz, cf. self.freq).
        self.y_sig = sum(np.sin(2 * np.pi * f * np.arange(256) / 256)
                         for f in range(8, 13))
        window = 10  # seconds of signal kept on screen
        self.sfreq = info.nominal_srate()
        n_samples = int(self.sfreq * window)
        self.n_chans = info.channel_count()

        # Read the first channel label, then walk its siblings. Only
        # n_chans - 1 siblings remain after the first one; looping n_chans
        # times stepped one element past the last channel.
        ch = description.child('channels').first_child()
        ch_names = [ch.child_value('label')]
        for _ in range(self.n_chans - 1):
            ch = ch.next_sibling()
            ch_names.append(ch.child_value('label'))

        # Number of cols and rows in the table.
        n_rows = self.n_chans
        n_cols = 1

        # Number of signals.
        m = n_rows * n_cols

        # Number of samples per signal.
        n = n_samples

        # Various signal amplitudes.
        amplitudes = np.zeros((m, n)).astype(np.float32)
        # Generate the signals as a (m, n) array.
        y = amplitudes

        color = color_palette("RdBu_r", n_rows)

        color = np.repeat(color, n, axis=0).astype(np.float32)
        # Signal 2D index of each vertex (row and col) and x-index (sample index
        # within each signal).
        index = np.c_[np.repeat(np.repeat(np.arange(n_cols), n_rows), n),
                      np.repeat(np.tile(np.arange(n_rows), n_cols), n),
                      np.tile(np.arange(n), m)].astype(np.float32)

        self.program = gloo.Program(VERT_SHADER, FRAG_SHADER)
        self.program['a_position'] = y.reshape(-1, 1)
        self.program['a_color'] = color
        self.program['a_index'] = index
        self.program['u_scale'] = (1., 1.)
        self.program['u_size'] = (n_rows, n_cols)
        self.program['u_n'] = n

        # text: one name label and one (initially empty) quality label per channel
        self.font_size = 48.
        self.names = []
        self.quality = []
        for ii in range(self.n_chans):
            text = visuals.TextVisual(ch_names[ii], bold=True, color='white')
            self.names.append(text)
            text = visuals.TextVisual('', bold=True, color='white')
            self.quality.append(text)

        self.quality_colors = color_palette("RdYlGn", 11)[::-1]

        self.scale = scale
        self.n_samples = n_samples
        self.filt = filt
        self.af = [1.0]  # FIR filter: denominator is 1

        self.data_f = np.zeros((n_samples, self.n_chans))
        self.data = np.zeros((n_samples, self.n_chans))

        # 8-15 Hz FIR band-pass designed against a window-sized template.
        self.bf = create_filter(self.data_f.T, self.sfreq, 8, 15.,
                                method='fir')

        zi = lfilter_zi(self.bf, self.af)
        self.filt_state = np.tile(zi, (self.n_chans, 1)).transpose()

        self._timer = app.Timer('auto', connect=self.on_timer, start=True)
        gloo.set_viewport(0, 0, *self.physical_size)
        gloo.set_state(clear_color='black', blend=True,
                       blend_func=('src_alpha', 'one_minus_src_alpha'))

        # Start the worker only once every attribute above exists. Starting
        # it first let self.worker run concurrently with __init__ and
        # potentially read attributes that were not assigned yet (and have
        # values like self.status overwritten afterwards).
        worker = threading.Thread(target=self.worker)
        worker.start()

        self.show()
예제 #15
0
def test_filters():
    """Test low-, band-, high-pass, and band-stop filters plus resampling.

    Covers argument validation in ``filter_data``, the short-filter
    warnings, agreement between a band-pass and a cascaded low/high-pass
    (plus band-stop complementarity), zero- and minimum-phase attenuation
    sanity checks, up/down resampling round trips (MNE, CUDA fallback and
    SciPy), anti-aliasing on downsampling, IIR filter construction, and
    ``picks`` handling for 3D input.
    """
    sfreq = 100
    sig_len_secs = 15

    a = rng.randn(2, sig_len_secs * sfreq)

    # let's test our catchers
    for fl in ['blah', [0, 1], 1000.5, '10ss', '10']:
        assert_raises(ValueError, filter_data, a, sfreq, 4, 8, None, fl,
                      1.0, 1.0, fir_design='firwin')
    for nj in ['blah', 0.5]:
        assert_raises(ValueError, filter_data, a, sfreq, 4, 8, None, 1000,
                      1.0, 1.0, n_jobs=nj, phase='zero', fir_design='firwin')
    assert_raises(ValueError, filter_data, a, sfreq, 4, 8, None, 100,
                  1., 1., fir_window='foo')
    assert_raises(ValueError, filter_data, a, sfreq, 4, 8, None, 10,
                  1., 1., fir_design='firwin')  # too short
    # > Nyq/2
    assert_raises(ValueError, filter_data, a, sfreq, 4, sfreq / 2., None,
                  100, 1.0, 1.0, fir_design='firwin')
    assert_raises(ValueError, filter_data, a, sfreq, -1, None, None,
                  100, 1.0, 1.0, fir_design='firwin')
    # these should work
    create_filter(a, sfreq, None, None, fir_design='firwin')
    create_filter(a, sfreq, None, None, method='iir')

    # check our short-filter warning:
    with warnings.catch_warnings(record=True) as w:
        # 'always' so a warning already emitted from the same location
        # (and therefore de-duplicated by the default filter) is recorded
        warnings.simplefilter('always')
        # Warning for low attenuation
        filter_data(a, sfreq, 1, 8, filter_length=256, fir_design='firwin2')
    assert_true(any('attenuation' in str(ww.message) for ww in w))
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        # Warning for too short a filter
        filter_data(a, sfreq, 1, 8, filter_length='0.5s', fir_design='firwin2')
    assert_true(any('Increase filter_length' in str(ww.message) for ww in w))

    # try new default and old default
    freqs = fftfreq(a.shape[-1], 1. / sfreq)
    A = np.abs(fft(a))
    kwargs = dict(fir_design='firwin')
    for fl in ['auto', '10s', '5000ms', 1024, 1023]:
        bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, **kwargs)
        bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0,
                         **kwargs)
        lp = filter_data(a, sfreq, None, 8, None, fl, 10, 1.0, n_jobs=2,
                         **kwargs)
        hp = filter_data(lp, sfreq, 4, None, None, fl, 1.0, 10, **kwargs)
        assert_allclose(hp, bp, rtol=1e-3, atol=1e-3)
        assert_allclose(bp + bs, a, rtol=1e-3, atol=1e-3)
        # Sanity check attenuation
        mask = (freqs > 5.5) & (freqs < 6.5)
        assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]),
                        1., atol=0.02)
        assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]),
                        0., atol=0.2)
        # now the minimum-phase versions
        bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0,
                         phase='minimum', **kwargs)
        bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0,
                         phase='minimum', **kwargs)
        assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]),
                        1., atol=0.11)
        assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]),
                        0., atol=0.3)

    # and since these are low-passed, downsampling/upsampling should be close
    n_resamp_ignore = 10

    def _assert_interior_close(x, y):
        # bp is (n_channels, n_times), so trim edge samples along the last
        # (time) axis; slicing axis 0 of a 2-channel array compares empties
        sl = slice(n_resamp_ignore, -n_resamp_ignore)
        assert_array_almost_equal(x[..., sl], y[..., sl], 2)

    bp_up_dn = resample(resample(bp, 2, 1, n_jobs=2), 1, 2, n_jobs=2)
    _assert_interior_close(bp, bp_up_dn)
    # note that on systems without CUDA, this line serves as a test for a
    # graceful fallback to n_jobs=1
    bp_up_dn = resample(resample(bp, 2, 1, n_jobs='cuda'), 1, 2, n_jobs='cuda')
    _assert_interior_close(bp, bp_up_dn)
    # test to make sure our resampling matches scipy's
    bp_up_dn = sp_resample(sp_resample(bp, 2 * bp.shape[-1], axis=-1,
                                       window='boxcar'),
                           bp.shape[-1], window='boxcar', axis=-1)
    _assert_interior_close(bp, bp_up_dn)

    # make sure we don't alias
    t = np.arange(sfreq * sig_len_secs) / float(sfreq)
    # make sinusoid close to the Nyquist frequency
    sig = np.sin(2 * np.pi * sfreq / 2.2 * t)
    # signal should disappear with 2x downsampling
    sig_gone = resample(sig, 1, 2)[n_resamp_ignore:-n_resamp_ignore]
    assert_array_almost_equal(np.zeros_like(sig_gone), sig_gone, 2)

    # let's construct some filters
    iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='ba')
    iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
    # this should be a third order filter
    assert_equal(iir_params['a'].size - 1, 3)
    assert_equal(iir_params['b'].size - 1, 3)
    iir_params = dict(ftype='butter', order=4, output='ba')
    iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
    assert_equal(iir_params['a'].size - 1, 4)
    assert_equal(iir_params['b'].size - 1, 4)
    iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='sos')
    iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
    # this should be a third order filter, which requires 2 SOS ((2, 6))
    assert_equal(iir_params['sos'].shape, (2, 6))
    iir_params = dict(ftype='butter', order=4, output='sos')
    iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
    assert_equal(iir_params['sos'].shape, (2, 6))

    # check that picks work for 3d array with one channel and picks=[0]
    a = rng.randn(5 * sfreq, 5 * sfreq)
    b = a[:, None, :]

    a_filt = filter_data(a, sfreq, 4, 8, None, 400, 2.0, 2.0,
                         fir_design='firwin')
    b_filt = filter_data(b, sfreq, 4, 8, [0], 400, 2.0, 2.0,
                         fir_design='firwin')

    assert_array_equal(a_filt[:, None, :], b_filt)

    # check for n-dimensional case
    a = rng.randn(2, 2, 2, 2)
    with warnings.catch_warnings(record=True):  # filter too long
        assert_raises(ValueError, filter_data, a, sfreq, 4, 8,
                      np.array([0, 1]), 100, 1.0, 1.0)
예제 #16
0
        if t_end > epochs.tmax:
            break

        window_epoch = epochs.copy().crop(tmin=t_start, tmax=t_end)
        window_epoch = window_epoch.copy().resample(sfreq=500)

        # hjort-csd centered at c3 at 500Hz
        csd = compute_current_source_density(window_epoch,
                                             sphere=sphere_center)
        csd_data = csd.get_data()

        # fir forward-backward filter 8-12 Hz
        from mne.filter import create_filter
        filt = create_filter(data=csd_data,
                             sfreq=sfreq,
                             l_freq=8,
                             h_freq=12,
                             method='iir')

        csd_filtered = csd.filter(l_freq=8, h_freq=12, method='iir')

        # trimming of 64 ms start and end

        # AR prediction Yule-Walker order = 30
        from mne.time_frequency.ar import _yule_walker
        a, e = _yule_walker(csd_filtered.get_data(), order=30)

        # prediction of 128 ms (64 trimmed + 64 future)

        # hilbert of the 128 ms signal
예제 #17
0
    #normalize delay
    left_delay = left_delay * 40 + 100
    right_delay = right_delay * 30 + 70

    print("left_thresh: %d   right_thresh: %d" % (left_thresh, right_thresh))
    print("left_delay: %d   right_delay: %d" % (left_delay, right_delay))
    input()

    """

    n_samples = int(fs * 10)  #10 is window size
    data_fLeft = np.zeros((n_samples, 1))
    data_fRight = np.zeros((n_samples, 1))
    af = [1.0]
    bf = create_filter(data_fLeft.T, fs, 3, 40.0, method="fir")  #do
    zi = lfilter_zi(bf, af)
    dataLeft = np.zeros((n_samples, 1))
    dataRight = np.zeros((n_samples, 1))
    filt_stateLeft = np.tile(zi, (1, 1)).transpose()
    filt_stateRight = np.tile(zi, (1, 1)).transpose()
    oldTimeL = datetime.datetime.now()  # initialize time delta
    oldTimeR = datetime.datetime.now()  # initialize time delta

    try:
        # The following loop acquires data, computes band powers, and calculates neurofeedback metrics based on those band powers
        while True:
            """Add some data at the end of each signal (real-time signals)."""
            samples, timestamps = inlet.pull_chunk(timeout=0.0,
                                                   max_samples=100)
            new_cursor.check_direction()
예제 #18
0
def test_reporting_iir(ftype, btype, order, output):
    """Test IIR filter reporting.

    Builds an IIR low-pass at 40 Hz (a band-pass 1-40 Hz when
    ``btype == 'bandpass'``), checks that the verbose log mentions the
    expected filter properties, and then checks the single-pass magnitude
    response at pass-band, transition-edge and stop-band frequencies.
    """
    sample_rate = 1000.
    is_bandpass = btype == 'bandpass'
    low_cut = 1. if is_bandpass else None
    params = dict(ftype=ftype, order=order, output=output)
    stop_atten = 20 if order == 1 else 80
    if ftype != 'ellip':
        pass_tol = 0.2
    else:
        params['rp'] = 3  # dB
        params['rs'] = stop_atten  # attenuation
        pass_tol = np.log10(params['rp']) + 0.01

    with catch_logging() as log:
        filt = create_filter(None, sample_rate, low_cut, 40., method='iir',
                             iir_params=params, verbose=True)
    # a band-pass doubles the order of the designed prototype
    eff_order = order * (2 if is_bandpass else 1)
    if output == 'ba':
        assert len(filt['b']) == eff_order + 1
    log_text = log.getvalue().lower()

    expected = ['IIR', 'zero-phase', 'two-pass forward and reverse',
                'non-causal', btype, ftype,
                'Filter order %d' % (eff_order * 2,),
                'Cutoff ' if btype == 'lowpass' else 'Cutoffs ']
    db_per_decade = -27.74
    if ftype == 'ellip':
        db_cutoff = -6.0
    elif order == 1 or ftype == 'butter':
        db_cutoff = -6.02
    else:
        assert ftype == 'bessel'
        assert order == 4
        db_cutoff = -15.16
    if btype == 'lowpass':
        expected.append('%0.2f dB' % (db_cutoff,))
    for key in expected:
        assert key.lower() in log_text

    # Verify some of the filter properties from the frequency response
    if output == 'ba':
        freqs, resp = freqz(filt['b'], filt['a'], worN=10000)
    else:
        freqs, resp = _sosfreqz(filt['sos'], worN=10000)
    freqs *= sample_rate / (2 * np.pi)
    mag = np.abs(resp)

    def nearest(target):
        # index of the response bin closest to ``target`` Hz
        return np.argmin(np.abs(freqs - target))

    pass_idx = [nearest(20)]
    stop_idx = [nearest(400.)]  # one decade above the 40 Hz cutoff
    edge_idx = [nearest(40.)]
    assert freqs[0] == 0.
    idx_0p1, idx_1 = nearest(0.1), nearest(1.)
    if not is_bandpass:
        pass_idx.extend([idx_0p1, idx_1])
    else:
        edge_idx.append(idx_1)
        stop_idx.append(idx_0p1)

    # the reported cutoff presumably describes the two-pass response, so a
    # single pass is checked against half of it in dB (dB / 40, not dB / 20)
    assert_allclose(mag[edge_idx], 10 ** (db_cutoff / 40.), atol=0.01)
    assert_allclose(mag[pass_idx], 1., atol=pass_tol)
    if ftype == 'butter' and btype == 'lowpass':
        assert_allclose(mag[stop_idx], 10 ** (db_per_decade * order / 20.),
                        rtol=0.01)
    elif ftype == 'ellip':
        assert_array_less(mag[stop_idx], 10 ** (-stop_atten / 20))
예제 #19
0
def test_filters():
    """Test low-, band-, high-pass, and band-stop filters plus resampling.

    Covers argument validation in ``filter_data``, the short-filter
    warnings, agreement between a band-pass and a cascaded low/high-pass
    (plus band-stop complementarity), up/down resampling round trips
    (MNE, CUDA fallback and SciPy), anti-aliasing on downsampling, IIR
    filter construction, and ``picks`` handling for 3D input.
    """
    sfreq = 100
    sig_len_secs = 15

    a = rng.randn(2, sig_len_secs * sfreq)

    # let's test our catchers
    for fl in ['blah', [0, 1], 1000.5, '10ss', '10']:
        assert_raises(ValueError, filter_data, a, sfreq, 4, 8, None, fl,
                      1.0, 1.0)
    for nj in ['blah', 0.5]:
        assert_raises(ValueError, filter_data, a, sfreq, 4, 8, None, 1000,
                      1.0, 1.0, n_jobs=nj, phase='zero', fir_window='hann')
    assert_raises(ValueError, filter_data, a, sfreq, 4, 8, None, 100,
                  1., 1., fir_window='foo')
    # > Nyq/2
    assert_raises(ValueError, filter_data, a, sfreq, 4, sfreq / 2., None,
                  100, 1.0, 1.0)
    assert_raises(ValueError, filter_data, a, sfreq, -1, None, None,
                  100, 1.0, 1.0)
    # these should work
    create_filter(a, sfreq, None, None)
    create_filter(a, sfreq, None, None, method='iir')

    # check our short-filter warning:
    with warnings.catch_warnings(record=True) as w:
        # 'always' so a warning already emitted from the same location
        # (and therefore de-duplicated by the default filter) is recorded
        warnings.simplefilter('always')
        # Warning for low attenuation
        filter_data(a, sfreq, 1, 8, filter_length=256)
    assert_true(any('attenuation' in str(ww.message) for ww in w))
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        # Warning for too short a filter
        filter_data(a, sfreq, 1, 8, filter_length='0.5s')
    assert_true(any('Increase filter_length' in str(ww.message) for ww in w))

    # try new default and old default
    for fl in ['auto', '10s', '5000ms', 1024]:
        bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0)
        bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0,
                         phase='zero', fir_window='hamming')
        lp = filter_data(a, sfreq, None, 8, None, fl, 10, 1.0, n_jobs=2,
                         phase='zero', fir_window='hamming')
        hp = filter_data(lp, sfreq, 4, None, None, fl, 1.0, 10, phase='zero',
                         fir_window='hamming')
        assert_array_almost_equal(hp, bp, 4)
        assert_array_almost_equal(bp + bs, a, 4)

    # and since these are low-passed, downsampling/upsampling should be close
    n_resamp_ignore = 10

    def _assert_interior_close(x, y):
        # bp is (n_channels, n_times), so trim edge samples along the last
        # (time) axis; slicing axis 0 of a 2-channel array compares empties
        sl = slice(n_resamp_ignore, -n_resamp_ignore)
        assert_array_almost_equal(x[..., sl], y[..., sl], 2)

    bp_up_dn = resample(resample(bp, 2, 1, n_jobs=2), 1, 2, n_jobs=2)
    _assert_interior_close(bp, bp_up_dn)
    # note that on systems without CUDA, this line serves as a test for a
    # graceful fallback to n_jobs=1
    bp_up_dn = resample(resample(bp, 2, 1, n_jobs='cuda'), 1, 2, n_jobs='cuda')
    _assert_interior_close(bp, bp_up_dn)
    # test to make sure our resampling matches scipy's
    bp_up_dn = sp_resample(sp_resample(bp, 2 * bp.shape[-1], axis=-1,
                                       window='boxcar'),
                           bp.shape[-1], window='boxcar', axis=-1)
    _assert_interior_close(bp, bp_up_dn)

    # make sure we don't alias
    t = np.arange(sfreq * sig_len_secs) / float(sfreq)
    # make sinusoid close to the Nyquist frequency
    sig = np.sin(2 * np.pi * sfreq / 2.2 * t)
    # signal should disappear with 2x downsampling
    sig_gone = resample(sig, 1, 2)[n_resamp_ignore:-n_resamp_ignore]
    assert_array_almost_equal(np.zeros_like(sig_gone), sig_gone, 2)

    # let's construct some filters
    iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='ba')
    iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
    # this should be a third order filter
    assert_equal(iir_params['a'].size - 1, 3)
    assert_equal(iir_params['b'].size - 1, 3)
    iir_params = dict(ftype='butter', order=4, output='ba')
    iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
    assert_equal(iir_params['a'].size - 1, 4)
    assert_equal(iir_params['b'].size - 1, 4)
    iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='sos')
    iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
    # this should be a third order filter, which requires 2 SOS ((2, 6))
    assert_equal(iir_params['sos'].shape, (2, 6))
    iir_params = dict(ftype='butter', order=4, output='sos')
    iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
    assert_equal(iir_params['sos'].shape, (2, 6))

    # check that picks work for 3d array with one channel and picks=[0]
    a = rng.randn(5 * sfreq, 5 * sfreq)
    b = a[:, None, :]

    a_filt = filter_data(a, sfreq, 4, 8, None, 400, 2.0, 2.0)
    b_filt = filter_data(b, sfreq, 4, 8, [0], 400, 2.0, 2.0)

    assert_array_equal(a_filt[:, None, :], b_filt)

    # check for n-dimensional case
    a = rng.randn(2, 2, 2, 2)
    with warnings.catch_warnings(record=True):  # filter too long
        assert_raises(ValueError, filter_data, a, sfreq, 4, 8,
                      np.array([0, 1]), 100, 1.0, 1.0)
예제 #20
0
def test_filters():
    """Test low-, band-, high-pass, and band-stop filters plus resampling.

    Covers argument validation in ``filter_data``, the short-filter
    warnings, agreement between a band-pass and a cascaded low/high-pass
    (plus band-stop complementarity), zero- and minimum-phase attenuation
    sanity checks, up/down resampling round trips (MNE, CUDA fallback and
    SciPy), anti-aliasing on downsampling, IIR filter construction,
    ``picks`` handling for 3D input, and the filter-length corner case
    from #4693.
    """
    sfreq = 100
    sig_len_secs = 15

    a = rng.randn(2, sig_len_secs * sfreq)

    # let's test our catchers
    for fl in ['blah', [0, 1], 1000.5, '10ss', '10']:
        pytest.raises(ValueError, filter_data, a, sfreq, 4, 8, None, fl,
                      1.0, 1.0, fir_design='firwin')
    for nj in ['blah', 0.5]:
        pytest.raises(ValueError, filter_data, a, sfreq, 4, 8, None, 1000,
                      1.0, 1.0, n_jobs=nj, phase='zero', fir_design='firwin')
    pytest.raises(ValueError, filter_data, a, sfreq, 4, 8, None, 100,
                  1., 1., fir_window='foo')
    pytest.raises(ValueError, filter_data, a, sfreq, 4, 8, None, 10,
                  1., 1., fir_design='firwin')  # too short
    # > Nyq/2
    pytest.raises(ValueError, filter_data, a, sfreq, 4, sfreq / 2., None,
                  100, 1.0, 1.0, fir_design='firwin')
    pytest.raises(ValueError, filter_data, a, sfreq, -1, None, None,
                  100, 1.0, 1.0, fir_design='firwin')
    # these should work
    create_filter(None, sfreq, None, None)
    create_filter(a, sfreq, None, None, fir_design='firwin')
    create_filter(a, sfreq, None, None, method='iir')

    # check our short-filter warning:
    with pytest.warns(RuntimeWarning, match='attenuation'):
        # Warning for low attenuation
        filter_data(a, sfreq, 1, 8, filter_length=256, fir_design='firwin2')
    with pytest.warns(RuntimeWarning, match='Increase filter_length'):
        # Warning for too short a filter
        filter_data(a, sfreq, 1, 8, filter_length='0.5s', fir_design='firwin2')

    # try new default and old default
    freqs = fftfreq(a.shape[-1], 1. / sfreq)
    A = np.abs(fft(a))
    kwargs = dict(fir_design='firwin')
    for fl in ['auto', '10s', '5000ms', 1024, 1023]:
        bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, **kwargs)
        bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0,
                         **kwargs)
        lp = filter_data(a, sfreq, None, 8, None, fl, 10, 1.0, n_jobs=2,
                         **kwargs)
        hp = filter_data(lp, sfreq, 4, None, None, fl, 1.0, 10, **kwargs)
        assert_allclose(hp, bp, rtol=1e-3, atol=1e-3)
        assert_allclose(bp + bs, a, rtol=1e-3, atol=1e-3)
        # Sanity check attenuation
        mask = (freqs > 5.5) & (freqs < 6.5)
        assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]),
                        1., atol=0.02)
        assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]),
                        0., atol=0.2)
        # now the minimum-phase versions
        bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0,
                         phase='minimum', **kwargs)
        bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0,
                         phase='minimum', **kwargs)
        assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]),
                        1., atol=0.11)
        assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]),
                        0., atol=0.3)

    # and since these are low-passed, downsampling/upsampling should be close
    n_resamp_ignore = 10

    def _assert_interior_close(x, y):
        # bp is (n_channels, n_times), so trim edge samples along the last
        # (time) axis; slicing axis 0 of a 2-channel array compares empties
        sl = slice(n_resamp_ignore, -n_resamp_ignore)
        assert_array_almost_equal(x[..., sl], y[..., sl], 2)

    bp_up_dn = resample(resample(bp, 2, 1, n_jobs=2), 1, 2, n_jobs=2)
    _assert_interior_close(bp, bp_up_dn)
    # note that on systems without CUDA, this line serves as a test for a
    # graceful fallback to n_jobs=1
    bp_up_dn = resample(resample(bp, 2, 1, n_jobs='cuda'), 1, 2, n_jobs='cuda')
    _assert_interior_close(bp, bp_up_dn)
    # test to make sure our resampling matches scipy's
    bp_up_dn = sp_resample(sp_resample(bp, 2 * bp.shape[-1], axis=-1,
                                       window='boxcar'),
                           bp.shape[-1], window='boxcar', axis=-1)
    _assert_interior_close(bp, bp_up_dn)

    # make sure we don't alias
    t = np.arange(sfreq * sig_len_secs) / float(sfreq)
    # make sinusoid close to the Nyquist frequency
    sig = np.sin(2 * np.pi * sfreq / 2.2 * t)
    # signal should disappear with 2x downsampling
    sig_gone = resample(sig, 1, 2)[n_resamp_ignore:-n_resamp_ignore]
    assert_array_almost_equal(np.zeros_like(sig_gone), sig_gone, 2)

    # let's construct some filters
    iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='ba')
    iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
    # this should be a third order filter
    assert iir_params['a'].size - 1 == 3
    assert iir_params['b'].size - 1 == 3
    iir_params = dict(ftype='butter', order=4, output='ba')
    iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
    assert iir_params['a'].size - 1 == 4
    assert iir_params['b'].size - 1 == 4
    # no 'output' key here: this case exercises the default output, which
    # the assertion below expects to be second-order sections ('sos')
    iir_params = dict(ftype='cheby1', gpass=1, gstop=20)
    iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
    # this should be a third order filter, which requires 2 SOS ((2, 6))
    assert iir_params['sos'].shape == (2, 6)
    iir_params = dict(ftype='butter', order=4, output='sos')
    iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
    assert iir_params['sos'].shape == (2, 6)

    # check that picks work for 3d array with one channel and picks=[0]
    a = rng.randn(5 * sfreq, 5 * sfreq)
    b = a[:, None, :]

    a_filt = filter_data(a, sfreq, 4, 8, None, 400, 2.0, 2.0,
                         fir_design='firwin')
    b_filt = filter_data(b, sfreq, 4, 8, [0], 400, 2.0, 2.0,
                         fir_design='firwin')

    assert_array_equal(a_filt[:, None, :], b_filt)

    # check for n-dimensional case
    a = rng.randn(2, 2, 2, 2)
    with pytest.warns(RuntimeWarning, match='longer'):
        pytest.raises(ValueError, filter_data, a, sfreq, 4, 8,
                      np.array([0, 1]), 100, 1.0, 1.0)

    # check corner case (#4693)
    h = create_filter(
        np.empty(10000), 1000., l_freq=None, h_freq=55.,
        h_trans_bandwidth=0.5, method='fir', phase='zero-double',
        fir_design='firwin', verbose=True)
    assert len(h) == 6601