Example #1
# module-level imports assumed from the original test module (import paths are assumptions)
import numpy as np
from scipy.signal import filtfilt
from ecogdata.filt.time import butter_bp

def test_parfiltfilt():
    from ecogdata.parallel.split_methods import filtfilt as filtfilt_p
    import ecogdata.parallel.sharedmem as shm
    r = np.random.randn(20, 2000)
    b, a = butter_bp(lo=30, hi=100, Fs=1000)

    f1 = filtfilt(b, a, r, axis=1, padtype=None)
    f2 = shm.shared_copy(r)
    # test w/o blocking
    filtfilt_p(f2, b, a, bsize=0)
    assert (f1 == f2).all()
    # test w/ blocking
    f2 = shm.shared_copy(r)
    filtfilt_p(f2, b, a, bsize=234)
    assert (f1 == f2).all()
Example #2
# module-level imports assumed from the original test module (import paths are assumptions)
import numpy as np
from scipy.signal import lfilter, lfilter_zi
from ecogdata.filt.time import butter_bp

def test_parfilt():
    from ecogdata.parallel.split_methods import bfilter as bfilter_p
    import ecogdata.parallel.sharedmem as shm
    r = np.random.randn(20, 2000)
    b, a = butter_bp(lo=30, hi=100, Fs=1000)
    zi = lfilter_zi(b, a)

    f1, _ = lfilter(b, a, r, axis=1, zi=zi * r[:, :1])
    f2 = shm.shared_copy(r)
    # test w/o blocking
    bfilter_p(b, a, f2, axis=1)
    assert (f1 == f2).all()
    # test w/ blocking
    f2 = shm.shared_copy(r)
    bfilter_p(b, a, f2, bsize=234, axis=1)
    assert (f1 == f2).all()
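A quick sketch of the shared-memory round trip both tests rely on (a minimal illustration, assuming only that shm.shared_copy returns a shared-memory-backed copy with the same values):

import numpy as np
import ecogdata.parallel.sharedmem as shm

a = np.random.randn(4, 100)
b = shm.shared_copy(a)        # copy backed by shared memory
assert np.array_equal(a, b)   # same values; b can now be filtered in place by worker processes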
Example #3
def notch_all(
        arr, Fs, lines=60.0, nzo=3,
        nwid=3.0, inplace=True, nmax=None, **filt_kwargs
):
    """Apply notch filtering to a array timeseries.

    Parameters
    ----------
    arr : ndarray
        timeseries
    Fs : float
        sampling frequency
    lines : [list of] float(s)
        One or more lines to notch.
    nzo : int (default 3)
        Number of zeros for the notch filter (more zeros --> deeper notch).
    nwid : float (default 3)
        Affects distance of notch poles from zeros (smaller --> closer).
        Zeros occur on the unit disk in the z-plane. Note that the
        stability of a digital filter depends on poles being within
        the unit disk.
    inplace : bool (default True)
        If True, filter the array in place; otherwise filter a copy
        (or the array passed as filt_kwargs['out'], if any).
    nmax : float (optional)
        If set, notch all multiples of (scalar-valued) lines up to nmax.
    **filt_kwargs :
        Additional keyword arguments for filter_array (e.g. filtfilt=True).

    Returns
    -------
    notched : ndarray

    """

    if inplace:
        # If filtering inplace, then set the output array as such
        filt_kwargs['out'] = arr
    elif filt_kwargs.get('out', None) is None:
        # If inplace is False and no output array is set,
        # then make the array copy here and do inplace filtering on the copy
        arr_f = shm.shared_copy(arr)
        arr = arr_f
        filt_kwargs['out'] = arr
    # otherwise an output array is set

    if isinstance(lines, (float, int)):
        # repeat lines until nmax
        nf = lines
        if nmax is None:
            nmax = nf
        nmax = min(nmax, Fs / 2.0)
        lines = [nf * i for i in range(1, int(nmax // nf) + 1)]
    else:
        lines = [x for x in lines if x < Fs/2]

    notch_defs = get_default_args(notch)
    notch_defs['nwid'] = nwid
    notch_defs['nzo'] = nzo
    notch_defs['Fs'] = Fs
    for nf in lines:
        notch_defs['fcut'] = nf
        arr_f = filter_array(arr, 'notch', inplace=False, design_kwargs=notch_defs, **filt_kwargs)
    return arr_f
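A hedged usage sketch for notch_all (the harmonics behavior follows from the scalar-lines branch above; filtfilt=True is forwarded through filt_kwargs to filter_array, as in the loaders below):

import numpy as np

arr = np.random.randn(20, 10000)
# scalar lines plus nmax notches 60, 120, and 180 Hz in one call
cleaned = notch_all(arr, Fs=1000.0, lines=60.0, nmax=180.0, inplace=False, filtfilt=True)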
Example #4
# module-level imports assumed from the original test module (import paths are assumptions)
import numpy as np
from scipy.signal import filtfilt
from ecogdata.filt.time import butter_bp, filter_array

def test_parfiltfilt():
    import ecogdata.parallel.sharedmem as shm
    r = np.random.randn(20, 2000)

    design_kwargs = dict(lo=30, hi=100, Fs=1000)
    filt_kwargs = dict(filtfilt=True)
    b, a = butter_bp(**design_kwargs)

    f1 = filtfilt(b, a, r, axis=1, padtype=None)
    # not inplace and serial
    f2 = filter_array(r,
                      inplace=False,
                      block_filter='serial',
                      design_kwargs=design_kwargs,
                      filt_kwargs=filt_kwargs)
    assert np.array_equal(f1, f2), 'serial filter with copy failed'
    # inplace and serial
    f2 = r.copy()
    f3 = filter_array(f2,
                      inplace=True,
                      block_filter='serial',
                      design_kwargs=design_kwargs,
                      filt_kwargs=filt_kwargs)
    assert np.array_equal(f1, f2), 'serial filter inplace failed'
    # not inplace and parallel
    rs = shm.shared_copy(r)
    f2 = filter_array(rs,
                      inplace=False,
                      block_filter='parallel',
                      design_kwargs=design_kwargs,
                      filt_kwargs=filt_kwargs)
    assert np.array_equal(f1, f2), 'parallel filter with copy failed'
    # inplace and parallel
    f2 = shm.shared_copy(r)
    f3 = filter_array(f2,
                      inplace=True,
                      block_filter='parallel',
                      design_kwargs=design_kwargs,
                      filt_kwargs=filt_kwargs)
    assert np.array_equal(f1, f2), 'parallel filter inplace failed'
Example #5
def _load_cooked(pth, test, half=False, avg=False):
    # august 21 test -- now using a common test-name prefix
    # with different recording channels appended
    test_pfx = osp.join(pth, test)

    chans = sio.loadmat(test_pfx + '.ndata.mat')['raw_data'].T
    try:
        trigs = sio.loadmat(test_pfx + '.ndatastim.mat')['qw'].T
        trigs = trigs.squeeze()
    except IOError:
        trigs = np.zeros(10)

    _columns = np.roll(columns, 2)
    _rows = np.roll(rows, 2)

    electrode_chans = _rows >= 0

    chan_flat = mat_to_flat((8, 8),
                            _rows[electrode_chans],
                            7 - _columns[electrode_chans],
                            col_major=False)
    chan_map = ChannelMap(chan_flat, (8, 8), col_major=False, pitch=0.406)

    # don't need to convert dynamic range
    #chans = np.zeros(chans_int.shape, dtype='d')
    #dr_lo, dr_hi = _dyn_range_lookup[dyn_range]
    #chans = convert_dyn_range(chans_int, 2**20, (dr_lo, dr_hi))

    if avg:
        chans = 0.5 * (chans[:, 1::2] + chans[:, 0::2])
        trigs = trigs[:, 1::2]
    if half:
        chans = chans[:, 1::2]
        trigs = trigs[:, 1::2]

    data = shm.shared_copy(chans[electrode_chans])
    disconnected = chans[~electrode_chans]

    binary_trig = (np.any(trigs == 1, axis=0)).astype('i')
    if binary_trig.any():
        pos_edge = np.where(np.diff(binary_trig) > 0)[0] + 1
    else:
        pos_edge = ()

    return data, disconnected, trigs, pos_edge, chan_map
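The trigger handling in these loaders is a one-line rising-edge detector on a binarized trace; a toy illustration:

import numpy as np

binary_trig = np.array([0, 0, 1, 1, 0, 1, 1, 0], dtype='i')
pos_edge = np.where(np.diff(binary_trig) > 0)[0] + 1
# pos_edge -> array([2, 5]): the samples where the trace steps from 0 to 1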
Example #6
def _load_cooked_pre_august_2014(pth, test, dyn_range, Fs):
    test_dir = osp.join(pth, test)

    chans_int = sio.loadmat(osp.join(test_dir, 'recs.mat'))['adcreads_sort']
    trigs = sio.loadmat(osp.join(test_dir, 'trigs.mat'))['stim_trig_sort']
    order = sio.loadmat(osp.join(test_dir,
                                 'channels.mat'))['channel_numbers_sort']

    # this tells me how to reorder the row/column vectors above to match
    # the channel order in the file
    order = order[:, -1]

    _columns = columns[order]
    _rows = rows[order]

    electrode_chans = _rows >= 0

    chan_flat = mat_to_flat((8, 8),
                            _rows[electrode_chans],
                            7 - _columns[electrode_chans],
                            col_major=False)
    chan_map = ChannelMap(chan_flat, (8, 8), col_major=False, pitch=0.406)

    chans = np.zeros(chans_int.shape, dtype='d')
    dr_lo, dr_hi = _dyn_range_lookup[dyn_range]
    #chans = (chans_int * ( Fs * (dr_hi - dr_lo) * 2**-20 )) + dr_lo*Fs
    chans = convert_dyn_range(chans_int, 2**20, (dr_lo, dr_hi))

    data = shm.shared_copy(chans[electrode_chans])
    disconnected = chans[~electrode_chans]

    binary_trig = (np.any(trigs == 1, axis=0)).astype('i')
    if binary_trig.any():
        pos_edge = np.where(np.diff(binary_trig) > 0)[0] + 1
    else:
        pos_edge = ()

    return data, disconnected, trigs, pos_edge, chan_map
Example #7
    def cache_slice(self, slicer, not_strided=False, sharedmem=False):
        """
        Caches a slice to yield during iteration. This takes place in a background thread for mapped sources.

        Parameters
        ----------
        slicer: slice
            Array __getitem__ slice spec.
        not_strided: bool
            If True, ensure that the sliced array is not strided (copy if needed).
        sharedmem: bool
            If True, cast the slice into a shared ctypes array.

        """
        if sharedmem:
            self._cache_output = shm.shared_copy(self[slicer])
        elif not_strided:
            output = self[slicer]
            if output.__array_interface__['strides']:
                self._cache_output = output.copy()
            else:
                self._cache_output = output
        else:
            self._cache_output = self[slicer]
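A hedged usage sketch (source stands for any mapped-source instance exposing this method; _cache_output is the attribute set above, though iteration normally consumes it for you):

import numpy as np

# stage the next block in shared memory so worker processes can filter it in place
source.cache_slice(np.s_[:, :10000], sharedmem=True)
block = source._cache_output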
Example #8
def load_openephys_ddc(exp_path,
                       test,
                       electrode,
                       drange,
                       trigger_idx,
                       rec_num='auto',
                       bandpass=(),
                       notches=(),
                       save=False,
                       snip_transient=True,
                       units='nA',
                       **extra):

    rawload = load_open_ephys_channels(exp_path, test, rec_num=rec_num)
    all_chans = rawload.chdata
    Fs = rawload.Fs

    d_chans = len(rows)
    ch_data = all_chans[:d_chans]
    if np.iterable(trigger_idx):
        trigger = all_chans[int(trigger_idx[0])]
    else:
        trigger = all_chans[int(trigger_idx)]

    electrode_chans = rows >= 0
    chan_flat = mat_to_flat((8, 8),
                            rows[electrode_chans],
                            7 - columns[electrode_chans],
                            col_major=False)
    chan_map = ChannelMap(chan_flat, (8, 8), col_major=False, pitch=0.406)

    dr_lo, dr_hi = _dyn_range_lookup[drange]  # drange is 0, 3, or 7
    ch_data = convert_dyn_range(ch_data, (-2**15, 2**15), (dr_lo, dr_hi))

    data = shm.shared_copy(ch_data[electrode_chans])
    disconnected = ch_data[~electrode_chans]

    trigger -= trigger.mean()
    binary_trig = (trigger > 100).astype('i')
    if binary_trig.any():
        pos_edge = np.where(np.diff(binary_trig) > 0)[0] + 1
    else:
        pos_edge = ()

    # change units if not nA
    if 'a' in units.lower():
        # this puts it as picoamps
        data *= Fs
        data = convert_scale(data, 'pa', units)
    elif 'c' in units.lower():
        data = convert_scale(data, 'pc', units)

    if bandpass:  # an empty tuple skips filtering; a (lo, hi) pair selects the bandpass
        (b, a) = ft.butter_bp(lo=bandpass[0], hi=bandpass[1], Fs=Fs)
        filtfilt(data, b, a)

    if notches:
        ft.notch_all(data, Fs, lines=notches, inplace=True, filtfilt=True)

    if snip_transient:
        snip_len = min(10000, pos_edge[0]) if len(pos_edge) else 10000
        data = data[..., snip_len:].copy()
        if len(disconnected):
            disconnected = disconnected[..., snip_len:].copy()
        if len(pos_edge):
            trigger = trigger[..., snip_len:]
            pos_edge -= snip_len

    dset = Bunch()
    dset.data = data
    dset.pos_edge = pos_edge
    dset.trigs = trigger
    dset.ground_chans = disconnected
    dset.Fs = Fs
    dset.chan_map = chan_map
    dset.bandpass = bandpass
    dset.transient_snipped = snip_transient
    dset.units = units
    dset.notches = notches
    return dset
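A hedged call sketch (the path, test name, and electrode are placeholders; drange must be a key of _dyn_range_lookup, which the comment above suggests is 0, 3, or 7):

dset = load_openephys_ddc('/path/to/exp', 'test_004', 'ddc_electrode', drange=3,
                          trigger_idx=64, bandpass=(1, 500), notches=(60.0,), units='nA')
data, Fs, chan_map = dset.data, dset.Fs, dset.chan_map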
Example #9
File: mux.py, Project: miketrumpis/ecogdata
def load_mux(exp_path,
             test,
             electrode,
             headstage,
             ni_daq_variant='',
             mux_connectors=(),
             bandpass=(),
             notches=(),
             trigger=0,
             bnc=(),
             mux_notches=(),
             save=False,
             snip_transient=True,
             units='uV'):
    """
    Load data from the MUX style headstage acquisition. Data is expected 
    to be organized along columns corresponding to the MUX units. The
    columns following sensor data columns are assumed to be a stimulus
    trigger followed by other BNC channels.

    The electrode information must be provided to determine the
    arrangement of recorded and grounded channels within the sensor
    data column.
    
    This preprocessing routine returns a Bunch container with the
    following items
    
    dset.data : nchan x ntime data array
    dset.ground_chans : m x ntime data array of grounded ADC channels
    dset.bnc : un-MUXed readout of the BNC channel(s)
    dset.chan_map : the channel-to-electrode mapping vector
    dset.Fs : sampling frequency
    dset.name : path + expID for the given data set
    dset.bandpass : bandpass filtering applied (if any)
    dset.trig : the logical value of the trigger channel (at MUX'd Fs)

    * If saving, then a table of the Bunch is written.
    * If snip_transient, then advance the timeseries past the bandpass
      filtering onset transient.
    
    """

    try:
        dset = try_saved(exp_path, test, bandpass)
        return dset
    except DataPathError:
        pass

    # say no to shared memory since it's created later on in this method
    loaded = rawload_mux(exp_path,
                         test,
                         headstage,
                         daq_variant=ni_daq_variant,
                         shm=False)
    channels, Fs, dshape, info = loaded
    nrow, ncol_data = dshape
    if channels.shape[0] >= nrow * ncol_data:
        ncol = channels.shape[0] // nrow
        channels = channels.reshape(ncol, nrow, -1)
    else:
        ncol = channels.shape[0]
        channels.shape = (ncol, -1, nrow)
        channels = channels.transpose(0, 2, 1)

    ## Grab BNC data

    if bnc:
        bnc_chans = [ncol_data + int(b) for b in bnc]
        bnc = np.zeros((len(bnc), nrow * channels.shape[-1]))
        for bc, col in zip(bnc, bnc_chans):
            bc[:] = channels[col].transpose().ravel()
        bnc = bnc.squeeze()

    try:
        trig_chans = channels[ncol_data + trigger].copy()
        pos_edge, trig = process_trigger(trig_chans)
    except IndexError:
        pos_edge = ()
        trig = ()

    ## Realize channel mapping

    chan_map, disconnected, reference = epins.get_electrode_map(
        electrode, connectors=mux_connectors)

    ## Data channels

    # if any pre-processing of multiplexed channels, do it here first
    if mux_notches:

        mux_chans = shm.shared_ndarray((ncol_data, channels.shape[-1], nrow))
        mux_chans[:] = channels[:ncol_data].transpose(0, 2, 1)
        mux_chans.shape = (ncol_data, -1)
        ft.notch_all(mux_chans, Fs, lines=mux_notches, filtfilt=True)
        mux_chans.shape = (ncol_data, channels.shape[-1], nrow)
        channels[:ncol_data] = mux_chans.transpose(0, 2, 1)
        del mux_chans

    rec_chans = channels[:ncol_data].reshape(nrow * ncol_data, -1)

    if units.lower() != 'v':
        convert_scale(rec_chans, 'v', units)

    g_chans = disconnected
    r_chans = reference
    d_chans = np.setdiff1d(np.arange(ncol_data * nrow),
                           np.union1d(g_chans, r_chans))

    data_chans = shm.shared_copy(rec_chans[d_chans])
    if len(g_chans):
        gnd_data = rec_chans[g_chans]
    else:
        gnd_data = ()
    if len(r_chans):
        ref_data = rec_chans[r_chans]
    else:
        ref_data = ()
    del rec_chans
    del channels

    # do highpass filtering for stationarity
    if bandpass:
        # manually remove DC from channels before filtering
        if bandpass[0] > 0:
            data_chans -= data_chans.mean(1)[:, None]
            # do a high order highpass to really crush the crappy
            # low frequency noise
            b, a = ft.butter_bp(lo=bandpass[0], Fs=Fs, ord=5)
            # b, a = ft.cheby1_bp(0.5, lo=bandpass[0], Fs=Fs, ord=5)
        else:
            b = [1]
            a = [1]
        if bandpass[1] > 0:
            b_lp, a_lp = ft.butter_bp(hi=bandpass[1], Fs=Fs, ord=3)
            b = np.convolve(b, b_lp)
            a = np.convolve(a, a_lp)

        filtfilt(data_chans, b, a)
        if len(ref_data):
            with parallel_controller(False):
                ref_data = np.atleast_2d(ref_data)
                filtfilt(ref_data, b, a)
                ref_data = ref_data.squeeze()

    if notches:
        ft.notch_all(data_chans,
                     Fs,
                     lines=notches,
                     inplace=True,
                     filtfilt=True)
        if len(ref_data):
            with parallel_controller(False):
                ref_data = np.atleast_2d(ref_data)
                ft.notch_all(ref_data,
                             Fs,
                             lines=notches,
                             inplace=True,
                             filtfilt=True)
                ref_data = ref_data.squeeze()

    if snip_transient:
        if isinstance(snip_transient, bool):
            snip_len = int(Fs * 5)
        else:
            snip_len = int(Fs * snip_transient)
        if len(pos_edge):
            snip_len = max(0, min(snip_len, pos_edge[0] - int(Fs)))
            pos_edge = pos_edge - snip_len
            trig = trig[..., snip_len:].copy()
        data_chans = data_chans[..., snip_len:].copy()
        gnd_data = gnd_data[..., snip_len:].copy()
        if len(ref_data):
            ref_data = ref_data[..., snip_len:].copy()
        if len(bnc):
            bnc = bnc[..., snip_len * nrow:].copy()

    # blockwise detrending for stationarity (currently disabled)
    ## detrend_window = int(round(0.750*Fs))
    ## ft.bdetrend(data_chans, bsize=detrend_window, type='linear', axis=-1)
    dset = ut.Bunch()
    dset.pos_edge = pos_edge
    dset.data = data_chans
    dset.ground_chans = gnd_data
    dset.ref_chans = ref_data
    dset.bnc = bnc
    dset.chan_map = chan_map
    dset.Fs = Fs
    #dset.name = os.path.join(exp_path, test)
    dset.bandpass = bandpass
    dset.trig = trig
    dset.transient_snipped = snip_transient
    dset.units = units
    dset.notches = notches
    dset.info = info

    if save:
        hf = os.path.join(exp_path, test + '_proc.h5')
        save_bunch(hf, '/', dset, mode='w')
    return dset
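A hedged call sketch (all identifiers are placeholders; the electrode and headstage names must match ecogdata's pinout tables):

dset = load_mux('/path/to/session', 'test_001', 'electrode_name', 'headstage_name',
                bandpass=(2, 200), notches=(60.0,), snip_transient=True, units='uV')
data, Fs, chan_map = dset.data, dset.Fs, dset.chan_map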