Exemplo n.º 1
0
def test_montage_bipolar_02():
    """montage(bipolar=...) must raise ValueError when trials disagree on channel order.

    The channel labels of the second trial are replaced by the labels of the
    first trial in reverse order.
    """
    bad_data = create_data(n_trial=2, attr=['chan'])
    reversed_labels = bad_data.chan[0][::-1]
    bad_data.axis['chan'][1] = reversed_labels

    with raises(ValueError):
        montage(bad_data, bipolar=100)
def test_trans_timefrequency_stft():
    """STFT with DPSS tapers yields a (chan, time, freq, taper) array."""
    seed(0)
    data = create_data(n_trial=1, n_chan=2, s_freq=s_freq, time=(0, dur))
    NW = 3
    n_tapers = 2 * NW - 1
    timefreq = timefrequency(data, method='stft', taper='dpss', NW=NW)

    assert timefreq.list_of_axes == ('chan', 'time', 'freq', 'taper')
    n_chan = data.number_of('chan')[0]
    expected_shape = (n_chan, 3, s_freq, n_tapers)
    assert timefreq.data[0].shape == expected_shape
def test_trans_frequency_complex():
    """Complex multitaper spectrum has axes (chan, freq, taper) and the expected shape."""
    seed(0)
    data = create_data(n_trial=1, n_chan=2, s_freq=s_freq, time=(0, dur))
    NW = 3
    spectrum = frequency(data, output='complex', taper='dpss', NW=NW)
    assert spectrum.list_of_axes == ('chan', 'freq', 'taper')

    n_chan = data.number_of('chan')[0]
    n_samples = dur * s_freq
    n_tapers = 2 * NW - 1
    assert spectrum.data[0].shape == (n_chan, n_samples, n_tapers)
def test_write_read_fieldtrip():
    """Data exported in FieldTrip format reads back identically and has no markers."""
    original = create_data(n_trial=1, n_chan=2)
    original.export(fieldtrip_file, export_format='fieldtrip')

    dataset = Dataset(fieldtrip_file)
    reloaded = dataset.read_data()
    assert_array_equal(original.data[0], reloaded.data[0])

    markers = dataset.read_markers()
    assert len(markers) == 0
Exemplo n.º 5
0
def test_pickle_01():
    """Round-trip a simulated dataset through pickle and compare the time axis.

    ``delete=False`` keeps the temporary file on disk after the first ``with``
    closes it, so it can be reopened by name for reading. Because of that the
    file is not removed automatically; it is deleted explicitly in ``finally``.
    """
    from os import remove

    data = create_data()

    tmpfile = NamedTemporaryFile(delete=False)
    try:
        with tmpfile as f:
            dump(data, f)

        with open(tmpfile.name, 'rb') as f:
            loaded = load(f)
    finally:
        # without this, every test run leaves a stray temporary file behind
        remove(tmpfile.name)

    assert_array_equal(data.axis['time'][0], loaded.time[0])
Exemplo n.º 6
0
def test_montage_bipolar_03():
    """montage(bipolar=...) raises ValueError unless 'chan' is the first axis."""
    data_wrongorder = create_data(attr=['chan'])

    # rebuild the axis mapping so that 'time' comes before 'chan'
    shuffled = OrderedDict()
    shuffled['time'] = data_wrongorder.axis['time']
    shuffled['chan'] = data_wrongorder.axis['chan']
    data_wrongorder.axis = shuffled

    with raises(ValueError):
        montage(data_wrongorder, bipolar=100)
Exemplo n.º 7
0
def test_copy_axis():
    """_copy must not bring back an axis that was removed from the dataset.

    Averaging over 'chan' leaves a single axis. Both ``_copy(axis=True)`` and
    ``_copy(axis=False)`` should then report as many axes as the reduced data.
    """
    reduced = math(create_data(), axis='chan', operator_name='mean')
    assert len(reduced.axis) == 1

    for copy_axis in (True, False):
        clone = reduced._copy(axis=copy_axis)
        assert len(reduced.axis) == len(clone.axis)
Exemplo n.º 8
0
def test_resample():
    """Downsample a 20 Hz sine from 1024 Hz to 256 Hz and compare spectra.

    The new time axis must match the data length. The summed boxcar spectrum
    of the first channel must agree with the original to 4 decimals.
    """
    seed(0)
    data = create_data(n_trial=1, s_freq=1024, signal='sine', sine_freq=20)

    NEW_FREQ = 256
    downsampled = resample(data, s_freq=NEW_FREQ)

    assert downsampled.s_freq == NEW_FREQ
    n_samples = downsampled.data[0].shape[1]
    assert n_samples == downsampled.number_of('time')[0]

    spectrum_orig = frequency(data, taper='boxcar')
    spectrum_down = frequency(downsampled, taper='boxcar')
    total_orig = sum(spectrum_orig.data[0][0, :])
    total_down = sum(spectrum_down.data[0][0, :])
    assert_array_almost_equal(total_orig, total_down, 4)
def test_brainvision_write():
    """Export to BrainVision and check the sizes of the written files."""
    data = create_data()
    data.export(brainvision_file, 'brainvision')

    main_size = brainvision_file.stat().st_size
    assert main_size == 822
    # the .eeg file size must equal the sample count times the float32 item size
    expected_eeg_size = data.data[0].size * dtype('float32').itemsize
    eeg_size = brainvision_file.with_suffix('.eeg').stat().st_size
    assert expected_eeg_size == eeg_size

    markers = [
        dict(name=label, start=onset, end=offset)
        for label, onset, offset in (
            ('a', 1, 2),
            ('b', 4, 5),
            ('c', 10, 12),
            ('d', 15, 17),
        )
    ]
    data.export(brainvision_file, 'brainvision', markers=markers)
    assert brainvision_file.with_suffix('.vmrk').stat().st_size == 556
Exemplo n.º 10
0
def test_trans_timefrequency_spectrogram():
    """Spectrogram power follows the squared amplitude ratio and obeys Parseval."""
    seed(0)
    data = create_data(n_trial=1, n_chan=2, s_freq=s_freq, time=(0, dur))
    # scale the first channel so it is ~5 times larger than the second one
    data.data[0][0, :] *= 5

    timefreq = timefrequency(data, method='spectrogram', detrend=None,
                             taper=None, overlap=0)
    p_time = math(data, operator_name=('square', 'sum'), axis='time')
    p_freq = math(timefreq, operator_name='sum', axis='freq')

    # per time window, the power ratio must lie between 4.7**2 and 5.7**2
    ratio = p_freq.data[0][0, :] / p_freq.data[0][1, :]
    assert (4.7**2 < ratio).all()
    assert (ratio < 5.7**2).all()

    # with random data, parseval only holds with boxcar
    freq_total = sum(p_freq(trial=0) * data.s_freq, axis=1)
    assert_array_almost_equal(p_time(trial=0), freq_total)
Exemplo n.º 11
0
def test_trans_frequency():
    """Check Parseval's theorem and amplitude scaling for several frequency() options.

    The first channel is multiplied by 5, so for each option set tested below
    its summed power must be roughly 5**2 times that of the second channel.
    """
    seed(0)
    data = create_data(n_trial=1, n_chan=2, s_freq=s_freq, time=(0, dur))
    # make the first channel ~5 times larger than the second one
    data.data[0][0, :] *= 5

    # with random data, parseval only holds with boxcar
    freq = frequency(data, detrend=None, taper=None, scaling='power')
    p_time = math(data, operator_name=('square', 'sum'), axis='time')
    p_freq = math(freq, operator_name='sum', axis='freq')
    assert_array_almost_equal(p_time(trial=0), p_freq(trial=0) * s_freq)

    # (keyword arguments for frequency(), upper bound on the amplitude ratio)
    ratio_cases = (
        ({'detrend': None, 'taper': None, 'scaling': 'power', 'duration': 1}, 5.4),
        ({'detrend': None, 'taper': None, 'scaling': 'energy', 'duration': 1}, 5.4),
        ({'detrend': None, 'taper': 'dpss', 'scaling': 'power'}, 5.4),
        ({'detrend': None, 'sides': 'two'}, 5.45),
    )
    for kwargs, upper in ratio_cases:
        spectrum = frequency(data, **kwargs)
        power = math(spectrum, operator_name='sum', axis='freq')
        ratio = power.data[0][0] / power.data[0][1]
        # power scales with the square of the amplitude ratio
        assert 4.7**2 < ratio < upper**2
Exemplo n.º 12
0

def test_trans_timefrequency_stft():
    """Check axis names and array shape of a DPSS-tapered STFT.

    NOTE(review): a function with this exact name also appears earlier in
    this listing; if both live in one module, this later definition shadows
    the earlier one and pytest collects only this one.
    """
    seed(0)
    data = create_data(n_trial=1, n_chan=2, s_freq=s_freq, time=(0, dur))
    NW = 3
    timefreq = timefrequency(data, method='stft', taper='dpss', NW=NW)
    assert timefreq.list_of_axes == ('chan', 'time', 'freq', 'taper')

    n_chan = data.number_of('chan')[0]
    # the taper axis is expected to have 2 * NW - 1 entries
    assert timefreq.data[0].shape == (n_chan, 3, s_freq, 2 * NW - 1)


seed(0)
data = create_data(
    n_trial=1, n_chan=2, s_freq=s_freq, time=(0, dur), amplitude=10,
)
# signal of channel 'chan00' in trial 0, used as input by the test(s) below
x = data(trial=0, chan='chan00')


def test_fft_spectrum_01():
    # NOTE(review): this call (presumably scipy.signal's private
    # `_spectral_helper`) is cut off in this listing: the argument list is never
    # closed and no assertions follow. Restore the remaining arguments and
    # checks from the original source file.
    f, t, Sxx = _spectral_helper(x,
                                 x,
                                 fs=s_freq,
                                 window='hann',
                                 nperseg=x.shape[0],
                                 noverlap=0,
                                 nfft=None,
                                 return_onesided=True,
                                 mode='psd',
def _save_line_plot(x, y, name, xaxis_title, yaxis_title, xrange=None):
    """Build a single-trace plotly line figure and save it with save_plotly_fig.

    Parameters
    ----------
    x, y : array-like
        Coordinates of the single Scatter trace.
    name : str
        Figure name passed to ``save_plotly_fig``.
    xaxis_title, yaxis_title : str
        Axis titles.
    xrange : tuple, optional
        Explicit x-axis range; when None, plotly chooses the range.
    """
    traces = [go.Scatter(x=x, y=y)]
    xaxis = dict(title=xaxis_title)
    if xrange is not None:
        xaxis['range'] = xrange
    layout = go.Layout(xaxis=xaxis, yaxis=dict(title=yaxis_title))
    fig = go.Figure(data=traces, layout=layout)
    save_plotly_fig(fig, name)


def test_trans_frequency_doc_01():
    """Generate the figures for the frequency-analysis documentation.

    Along the way, checks Parseval's theorem for the default spectrum and for
    the energy-scaled spectrum of a Hann-windowed signal.
    """
    TIME_TITLE = 'Time (s)'
    FREQ_TITLE = 'Frequency (Hz)'
    PSD_TITLE = 'Amplitude (V<sup>2</sup>/Hz)'
    FREQ_RANGE = (0, 20)

    # generate data
    data = create_data(n_chan=2, signal='sine', amplitude=1)
    _save_line_plot(data.time[0], data(trial=0, chan='chan00'),
                    'freq_01_data', TIME_TITLE, 'Amplitude (V)')

    # default options
    freq = frequency(data, detrend=None)
    _save_line_plot(freq.freq[0], freq(trial=0, chan='chan00'),
                    'freq_02_freq', FREQ_TITLE, PSD_TITLE, xrange=FREQ_RANGE)

    # Parseval's theorem
    p_time = math(data, operator_name=('square', 'sum'), axis='time')
    p_freq = math(freq, operator_name='sum', axis='freq')
    assert_array_almost_equal(p_time(trial=0), p_freq(trial=0) * data.s_freq)

    # generate very long data (1 s Hann windows with 0.5 overlap)
    data = create_data(n_chan=1, signal='sine', time=(0, 100))
    freq = frequency(data, taper='hann', duration=1, overlap=0.5)
    _save_line_plot(freq.freq[0], freq(trial=0, chan='chan00'),
                    'freq_03_welch', FREQ_TITLE, PSD_TITLE, xrange=FREQ_RANGE)

    # dpss
    data = create_data(n_chan=1, signal='sine')
    freq = frequency(data, taper='dpss', halfbandwidth=5)
    _save_line_plot(freq.freq[0], freq(trial=0, chan='chan00'),
                    'freq_04_dpss', FREQ_TITLE, PSD_TITLE, xrange=FREQ_RANGE)

    # ESD
    DURATION = 2
    data = create_data(n_chan=1, signal='sine', time=(0, DURATION))
    data.data[0][0, :] *= hann(data.data[0].shape[1])
    _save_line_plot(data.time[0], data(trial=0, chan='chan00'),
                    'freq_05_esd', TIME_TITLE, 'Amplitude (V)')

    freq = frequency(data, detrend=None, scaling='energy')
    _save_line_plot(freq.freq[0], freq(trial=0, chan='chan00'),
                    'freq_06_esd', FREQ_TITLE, 'Amplitude (V<sup>2</sup>)',
                    xrange=FREQ_RANGE)

    # Parseval's theorem (energy scaling also needs the signal duration)
    p_time = math(data, operator_name=('square', 'sum'), axis='time')
    p_freq = math(freq, operator_name='sum', axis='freq')
    assert_array_almost_equal(p_time(trial=0),
                              p_freq(trial=0) * data.s_freq * DURATION)

    # Complex
    data = create_data(n_chan=1, signal='sine')
    freq = frequency(data, output='complex', sides='two', scaling='energy')
    _save_line_plot(freq.freq[0], abs(freq(trial=0, chan='chan00', taper=0)),
                    'freq_07_complex', FREQ_TITLE, 'Amplitude (V)')
Exemplo n.º 14
0
def test_simulate_01():
    """create_data returns an object array holding one (chan, time) trial."""
    data = create_data()
    trials = data.data
    assert trials.dtype == 'O'
    assert trials.shape == (1, )  # one trial

    first_trial = trials[0]
    assert first_trial.shape[0] == len(data.axis['chan'][0])
    assert first_trial.shape[1] == len(data.axis['time'][0])
Exemplo n.º 15
0
from collections import OrderedDict
from numpy import sum, zeros
from numpy.random import seed
from numpy.testing import assert_array_equal, assert_array_almost_equal
from pytest import raises

from wonambi.utils import create_data
from wonambi.trans import montage
from wonambi.trans.montage import compute_average_regress

seed(0)
data = create_data(attr=['chan'])


def test_montage_01():
    """Passing ref_chan as a bare string (not a list of labels) raises TypeError."""
    bare_label = 'chan00'
    with raises(TypeError):
        montage(data, ref_chan=bare_label)


def test_montage_02():
    """Re-referencing to a single channel turns that channel into all zeros."""
    reref = montage(data, ref_chan=['chan00'])
    ref_signal = reref(chan='chan00')[0]
    assert_array_equal(ref_signal, zeros(ref_signal.shape))


def test_montage_03():
    """Re-referencing to the average of two channels makes them sum to zero.

    Assuming montage subtracts the mean of the ``ref_chan`` channels (as
    test_montage_02 implies for a single channel), chan01 and chan02 become
    negatives of each other. Their sum over channels should therefore vanish
    at every sample.
    """
    CHAN = ('chan01', 'chan02')
    reref = montage(data, ref_chan=CHAN)
    dat1 = reref(chan=CHAN)
    # dat1[0] is the first trial; axis 0 is assumed to be 'chan'
    assert_array_almost_equal(sum(dat1[0], axis=0),
                              zeros(dat1[0].shape[1]))
Exemplo n.º 16
0
def test_simulate_07():
    """ChanFreq data starts at the lower freq limit and stays below the upper one."""
    FREQ_LIMITS = (0, 10)
    low, high = FREQ_LIMITS
    data = create_data(datatype='ChanFreq', freq=FREQ_LIMITS)
    freq_values = data.axis['freq'][0]
    assert freq_values[0] == low
    assert freq_values[-1] < high
Exemplo n.º 17
0
def test_simulate_channels_00():
    """Requesting the 'chan' attribute makes it appear in data.attr."""
    requested_attrs = ['chan']
    data = create_data(attr=requested_attrs)
    assert 'chan' in data.attr
Exemplo n.º 18
0
def test_simulate_06():
    """The time axis starts at the lower limit and stays below the upper one."""
    TIME_LIMITS = (0, 10)
    start, stop = TIME_LIMITS
    data = create_data(time=TIME_LIMITS)
    time_values = data.axis['time'][0]
    assert time_values[0] == start
    assert time_values[-1] < stop
Exemplo n.º 19
0
def test_simulate_05():
    """ChanTimeFreq data dimensions match the chan, time and freq axes, in order."""
    data = create_data(datatype='ChanTimeFreq')
    for dim, axis_name in enumerate(('chan', 'time', 'freq')):
        assert data.data[0].shape[dim] == len(data.axis[axis_name][0])
Exemplo n.º 20
0
def test_montage_bipolar_01():
    """Bipolar montage on data created without attr=['chan'] raises ValueError."""
    plain_data = create_data()
    with raises(ValueError):
        montage(plain_data, bipolar=100)
Exemplo n.º 21
0
from wonambi import Dataset
from wonambi.attr.chan import create_sphere_around_elec
from wonambi.attr import Freesurfer, Surf
from wonambi.source import Morph
from wonambi.utils import create_data

from .paths import (
    fs_path,
    surf_path,
    eeglab_hdf5_1_file,
    hdf5_file,
)

# Module-level fixtures shared by the import tests below.
# NOTE(review): these constructors run at import time and may read files
# from fs_path / surf_path; keep their order unchanged.
fs = Freesurfer(fs_path)
data = create_data()
surf = Surf(surf_path)


def test_import_chan():
    """create_sphere_around_elec is expected to raise ImportError here.

    Presumably an optional dependency it needs is not installed in this test
    environment; confirm against the test setup.
    """
    with raises(ImportError):
        create_sphere_around_elec(None, '')


def test_import_anat():
    """Freesurfer.surface_ras_shift and read_label each raise ImportError here."""
    accessors = (
        lambda: fs.surface_ras_shift,
        lambda: fs.read_label(''),
    )
    for access in accessors:
        with raises(ImportError):
            access()
Exemplo n.º 22
0
def test_edf_write():
    """Writing simulated data to an EDF file completes without raising."""
    simulated = create_data()
    edf_path = EXPORTED_PATH / 'export.edf'
    write_edf(simulated, edf_path)
Exemplo n.º 23
0
def test_simulate_02():
    """An unknown datatype name is rejected with ValueError."""
    unknown_datatype = 'xxx'
    with raises(ValueError):
        create_data(datatype=unknown_datatype)
Exemplo n.º 24
0
def test_simulate_03():
    """n_trial and chan_name control the number of trials and channels."""
    N_TRIAL = 10
    channel_names = ['chan0', 'chan2']
    data = create_data(n_trial=N_TRIAL, chan_name=channel_names)
    assert data.data.shape[0] == N_TRIAL
    assert data.data[0].shape[0] == len(channel_names)
Exemplo n.º 25
0
from numpy.random import seed

from wonambi.utils import create_data
from wonambi.trans import concatenate

# Fix the global numpy seed before simulating (presumably create_data draws
# from numpy's global RNG), then build the 5-trial dataset shared by the tests.
seed(0)
data = create_data(n_trial=5)


def test_concatenate_trial():
    """Concatenating along 'trial' leaves a single trial with a 'trial_axis'."""
    merged = concatenate(data, axis='trial')
    expected_axes = ('chan', 'time', 'trial_axis')
    assert merged.number_of('trial') == 1
    assert merged.list_of_axes == expected_axes


def test_concatenate_axis():
    """Concatenating along 'time' yields n_trial times the samples of one trial."""
    merged = concatenate(data, axis='time')
    merged_samples = merged.number_of('time')[0]
    samples_per_trial = data.number_of('time')[0]
    n_trials = data.number_of('trial')
    assert merged_samples == samples_per_trial * n_trials
Exemplo n.º 26
0
from numpy.testing import assert_array_equal

from wonambi import Dataset
from wonambi.ioeeg import write_wonambi
from wonambi.utils import create_data

from .paths import wonambi_file

# Single-trial simulated dataset written to and read back from the wonambi format.
gen_data = create_data(n_trial=1)


def test_wonambi_write_read():
    """Round-trip gen_data through the wonambi format and compare the first trial."""
    write_wonambi(gen_data, wonambi_file, subj_id='test_subj')
    reloaded = Dataset(wonambi_file).read_data()
    assert_array_equal(reloaded(trial=0), gen_data(trial=0))