コード例 #1
0
ファイル: test_filter.py プロジェクト: zuxfoucault/mne-python
def test_filters():
    """Test low-, band-, high-pass, and band-stop filters plus resampling.

    Exercises, in order: ``filter_data`` argument validation, the
    short-filter warnings, consistency of band-pass / band-stop /
    low-pass + high-pass outputs for several ``filter_length`` specs
    (zero-phase and minimum-phase), up/down resampling round trips
    (MNE and SciPy), suppression of a near-Nyquist sinusoid on 2x
    downsampling, ``construct_iir_filter`` orders, ``picks`` on 3D
    input, and the filter-length corner case from #4693.
    """
    sfreq = 100
    sig_len_secs = 15

    # two rows of random samples, sig_len_secs seconds long at sfreq
    a = rng.randn(2, sig_len_secs * sfreq)

    # let's test our catchers: each invalid filter_length must be rejected
    for fl in ['blah', [0, 1], 1000.5, '10ss', '10']:
        pytest.raises(ValueError, filter_data, a, sfreq, 4, 8, None, fl,
                      1.0, 1.0, fir_design='firwin')
    # ... and each invalid n_jobs value
    for nj in ['blah', 0.5]:
        pytest.raises(ValueError, filter_data, a, sfreq, 4, 8, None, 1000,
                      1.0, 1.0, n_jobs=nj, phase='zero', fir_design='firwin')
    # unknown FIR window name
    pytest.raises(ValueError, filter_data, a, sfreq, 4, 8, None, 100,
                  1., 1., fir_window='foo')
    # filter_length too short
    pytest.raises(ValueError, filter_data, a, sfreq, 4, 8, None, 10,
                  1., 1., fir_design='firwin')
    # h_freq placed exactly at the Nyquist frequency (sfreq / 2)
    pytest.raises(ValueError, filter_data, a, sfreq, 4, sfreq / 2., None,
                  100, 1.0, 1.0, fir_design='firwin')
    # negative l_freq
    pytest.raises(ValueError, filter_data, a, sfreq, -1, None, None, 100,
                  1.0, 1.0, fir_design='firwin')
    # these should work
    create_filter(None, sfreq, None, None)
    create_filter(a, sfreq, None, None, fir_design='firwin')
    create_filter(a, sfreq, None, None, method='iir')

    # check our short-filter warnings:
    with pytest.warns(RuntimeWarning, match='attenuation'):
        # Warning for low attenuation
        filter_data(a, sfreq, 1, 8, filter_length=256, fir_design='firwin2')
    with pytest.warns(RuntimeWarning, match='Increase filter_length'):
        # Warning for too short a filter
        filter_data(a, sfreq, 1, 8, filter_length='0.5s', fir_design='firwin2')

    # try new default and old default
    freqs = fftfreq(a.shape[-1], 1. / sfreq)
    A = np.abs(fft(a))
    # FFT bins inside the 4-8 Hz band: bp should pass them, bs reject them
    mask = (freqs > 5.5) & (freqs < 6.5)
    kwargs = dict(fir_design='firwin')
    for fl in ['auto', '10s', '5000ms', 1024, 1023]:
        bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0, **kwargs)
        bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0,
                         **kwargs)
        lp = filter_data(a, sfreq, None, 8, None, fl, 10, 1.0, n_jobs=2,
                         **kwargs)
        hp = filter_data(lp, sfreq, 4, None, None, fl, 1.0, 10, **kwargs)
        # low-pass followed by high-pass should reproduce the band-pass,
        # and band-pass + band-stop should reconstruct the input
        assert_allclose(hp, bp, rtol=1e-3, atol=1e-3)
        assert_allclose(bp + bs, a, rtol=1e-3, atol=1e-3)
        # Sanity check attenuation
        assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]),
                        1., atol=0.02)
        assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]),
                        0., atol=0.2)
        # now the minimum-phase versions (checked with looser tolerances)
        bp = filter_data(a, sfreq, 4, 8, None, fl, 1.0, 1.0,
                         phase='minimum', **kwargs)
        bs = filter_data(a, sfreq, 8 + 1.0, 4 - 1.0, None, fl, 1.0, 1.0,
                         phase='minimum', **kwargs)
        assert_allclose(np.mean(np.abs(fft(bp)[:, mask]) / A[:, mask]),
                        1., atol=0.11)
        assert_allclose(np.mean(np.abs(fft(bs)[:, mask]) / A[:, mask]),
                        0., atol=0.3)

    # and since these are low-passed, downsampling/upsampling should be close.
    # ``bp`` here is the minimum-phase band-pass from the last loop iteration;
    # it has shape (2, n_times), so edge samples must be trimmed along the
    # LAST axis. Slicing axis 0 (as ``bp[10:-10]`` would) keeps zero of the
    # two rows and makes every comparison below vacuously pass.
    n_resamp_ignore = 10
    trim = np.s_[..., n_resamp_ignore:-n_resamp_ignore]
    bp_up_dn = resample(resample(bp, 2, 1, n_jobs=2), 1, 2, n_jobs=2)
    assert_array_almost_equal(bp[trim], bp_up_dn[trim], 2)
    # note that on systems without CUDA, this line serves as a test for a
    # graceful fallback to n_jobs=1
    bp_up_dn = resample(resample(bp, 2, 1, n_jobs='cuda'), 1, 2, n_jobs='cuda')
    assert_array_almost_equal(bp[trim], bp_up_dn[trim], 2)
    # the same up/down round trip through scipy.signal.resample (boxcar)
    bp_up_dn = sp_resample(sp_resample(bp, 2 * bp.shape[-1], axis=-1,
                                       window='boxcar'),
                           bp.shape[-1], window='boxcar', axis=-1)
    assert_array_almost_equal(bp[trim], bp_up_dn[trim], 2)

    # make sure we don't alias
    t = np.arange(sfreq * sig_len_secs) / float(sfreq)
    # make sinusoid close to the Nyquist frequency
    sig = np.sin(2 * np.pi * sfreq / 2.2 * t)
    # signal should disappear with 2x downsampling
    sig_gone = resample(sig, 1, 2)[trim]
    assert_array_almost_equal(np.zeros_like(sig_gone), sig_gone, 2)

    # let's construct some filters
    iir_params = dict(ftype='cheby1', gpass=1, gstop=20, output='ba')
    iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
    # this should be a third order filter
    assert iir_params['a'].size - 1 == 3
    assert iir_params['b'].size - 1 == 3
    iir_params = dict(ftype='butter', order=4, output='ba')
    iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
    assert iir_params['a'].size - 1 == 4
    assert iir_params['b'].size - 1 == 4
    # no 'output' key given here; the assertion expects SOS form
    iir_params = dict(ftype='cheby1', gpass=1, gstop=20)
    iir_params = construct_iir_filter(iir_params, 40, 80, 1000, 'low')
    # this should be a third order filter, which requires 2 SOS ((2, 6))
    assert iir_params['sos'].shape == (2, 6)
    iir_params = dict(ftype='butter', order=4, output='sos')
    iir_params = construct_iir_filter(iir_params, 40, None, 1000, 'low')
    assert iir_params['sos'].shape == (2, 6)

    # check that picks work for 3d array with one channel and picks=[0]
    a = rng.randn(5 * sfreq, 5 * sfreq)
    b = a[:, None, :]
    a_filt = filter_data(a, sfreq, 4, 8, None, 400, 2.0, 2.0,
                         fir_design='firwin')
    b_filt = filter_data(b, sfreq, 4, 8, [0], 400, 2.0, 2.0,
                         fir_design='firwin')
    assert_array_equal(a_filt[:, None, :], b_filt)

    # check for n-dimensional case
    a = rng.randn(2, 2, 2, 2)
    with pytest.warns(RuntimeWarning, match='longer'):
        pytest.raises(ValueError, filter_data, a, sfreq, 4, 8,
                      np.array([0, 1]), 100, 1.0, 1.0)

    # check corner case (#4693)
    h = create_filter(np.empty(10000), 1000., l_freq=None, h_freq=55.,
                      h_trans_bandwidth=0.5, method='fir',
                      phase='zero-double', fir_design='firwin', verbose=True)
    assert len(h) == 6601
コード例 #2
0
def _my_trans(data):
    """Return the FFT of ``data`` duplicated along a new axis 2.

    The spectrum produced by ``fft(data)`` gets a length-1 axis inserted at
    position 2, which is then repeated twice, so both slices ``[:, :, 0]``
    and ``[:, :, 1]`` hold the same spectrum. The second element of the
    returned pair is always ``None``.
    """
    spectrum = fft(data)
    # Insert a singleton axis at position 2 and tile it to length 2.
    doubled = np.repeat(spectrum[:, :, np.newaxis], 2, axis=2)
    return doubled, None