Example #1
 def lookup(self, dset_name):
     path = self._dset_to_path(dset_name)
     try:
         node = load_bunch(self.dbfile, path)
     except (IOError, NoSuchNodeError):
         return Bunch()
     return node
Example #2
    def enumerate_conditions(self):
        """
        Return the map of condition labels (counting numbers, beginning
        at 1), as well as tables to decode the labels into multiple
        stimulation parameters.

        Returns
        -------

        conditions : ndarray, len(experiment)
            The condition label (0 <= c < n_conditions) at each stim event

        cond_table : Bunch
            This Bunch contains an entry for every stimulation parameter.
            The entries are lookup tables for the parameter values.

        """

        if not len(self.time_stamps):
            return (), Bunch()
        tab_len = len(self.time_stamps)
        if not self.enum_tables:
            return np.ones(tab_len, 'i'), Bunch()

        all_uvals = []
        conditions = np.zeros(len(self.time_stamps), 'i')
        for name in self.enum_tables:
            tab = self.__dict__[name][:tab_len]
            uvals = np.unique(tab)
            all_uvals.append(uvals)
            conditions *= len(uvals)
            for n, val in enumerate(uvals):
                conditions[tab == val] += n

        n_vals = list(map(len, all_uvals))
        for n in range(len(all_uvals)):
            # first repeat the current table of values to span the
            # following "least-significant" values, then tile the
            # repeated set to match the preceding "most-significant"
            # values
            utab = all_uvals[n]
            if all_uvals[n + 1:]:
                rep = reduce(np.multiply, n_vals[n + 1:])
                utab = np.repeat(utab, rep)
            if n > 0:
                tiles = reduce(np.multiply, n_vals[:n])
                utab = np.tile(utab, tiles)
            all_uvals[n] = utab

        return conditions, Bunch(**dict(zip(self.enum_tables, all_uvals)))
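
To make the label arithmetic concrete, here is a small standalone sketch (plain numpy, with made-up parameter values) of the same mixed-radix encoding used above: each enum table contributes one "digit", with earlier tables acting as the more significant digits.

import numpy as np

# two made-up stimulation parameters recorded at 6 events
tone = np.array([100, 100, 200, 200, 100, 200])   # 2 unique values
level = np.array([10, 20, 10, 20, 20, 10])        # 2 unique values

all_uvals = []
conditions = np.zeros(len(tone), 'i')
for tab in (tone, level):
    uvals = np.unique(tab)
    all_uvals.append(uvals)
    conditions *= len(uvals)
    for n, val in enumerate(uvals):
        conditions[tab == val] += n
# conditions -> [0 1 2 3 1 2]: tone is the most-significant digit

# decoding tables, repeated/tiled exactly as in enumerate_conditions
n_vals = list(map(len, all_uvals))
tone_tab = np.repeat(all_uvals[0], n_vals[1])   # [100 100 200 200]
level_tab = np.tile(all_uvals[1], n_vals[0])    # [ 10  20  10  20]
# so condition label c decodes to (tone_tab[c], level_tab[c])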
Example #3
def uniform_bunch_case(b):
    b_lower = Bunch()
    for k, v in b.items():
        if isinstance(k, str):
            b_lower[k.lower()] = v
        else:
            b_lower[k] = v
    return b_lower
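
A quick usage sketch (assuming `Bunch` is the attribute-accessible dict used throughout these examples; the import path below is an assumption):

from ecogdata.util import Bunch   # assumed location of Bunch

b = Bunch()
b['Fs'] = 1000.0
b['UNITS'] = 'uV'
b[42] = 'non-string keys pass through unchanged'
lb = uniform_bunch_case(b)
sorted(k for k in lb if isinstance(k, str))   # ['fs', 'units']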
Example #4
def load_params(as_string=False):
    cfg = ConfigParser()
    # Look for custom global config in ~/.mjt_exp_conf.txt
    # If nothing found, use a default one here
    cpath = os.path.expanduser('~/.mjt_exp_conf.txt')
    if not os.path.exists(cpath):
        cpath = os.path.split(os.path.abspath(__file__))[0]
        cpath = os.path.join(cpath, 'global_config.txt')
    cfg.read(cpath)

    params = Bunch()
    for opt in cfg.options('globals'):
        if as_string:
            params[opt] = OVERRIDE.get(opt, cfg.get('globals', opt))
        else:
            val = OVERRIDE.get(opt, cfg.get('globals', opt))
            params[opt] = parse_param(opt, val, all_keys)
    for k in all_keys:
        params.setdefault(k, '')
    return params
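
For reference, `load_params` reads an INI-style file with a `[globals]` section; keys such as `memory_limit` and `stash_path` (used in later examples) would live there. A minimal sketch, with illustrative values only:

# contents of ~/.mjt_exp_conf.txt (values are illustrative)
# [globals]
# stash_path = /data/stash
# memory_limit = 16e9

params = load_params()
# parsed values, assuming these keys are present in all_keys / the config
mem_limit = float(params.memory_limit)
stash = params.stash_path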
Example #5
 def __init__(self,
              time_stamps=(),
              event_tables=dict(),
              condition_order=(),
              **attrib):
     if time_stamps is None:
         self.time_stamps = ()
     else:
         self.time_stamps = time_stamps
     self._fill_tables(**event_tables)
     self.stim_props = Bunch(**attrib)
     if condition_order:
         self.set_enum_tables(condition_order)
Example #6
def load_arr(dfile, pruned_pts=(), auto_prune=True, trig=-1):
    try:
        m = sio.loadmat(dfile)
        d = m.pop('data')
        Fs = float(m['Fs'][0, 0])
        nrow = int(m['numRow'][0, 0])
        ncol = int(m['numCol'][0, 0])
        del m
    except NotImplementedError:
        d, shape, Fs = load_hdf5_arr(dfile)
        nrow, ncol = shape

    # heuristic: if there are more columns than rows, the array is stored
    # channels x time and needs transposing
    t = d.shape[0] < d.shape[1]
    if not pruned_pts and auto_prune:
        pruned_pts = get_load_snips(dfile)

    tx = np.arange(max(d.shape)) / Fs
    segs = ()
    if pruned_pts:
        #d = d.T[pruned_pts, :nrow*ncol] if t else d[pruned_pts, :nrow*ncol]
        if t:
            data, _ = pruned_arr(d[:nrow * ncol, :].T, pruned_pts, axis=0)
        else:
            data, _ = pruned_arr(d[:, :nrow * ncol], pruned_pts, axis=0)
        tx, segs = pruned_arr(tx, pruned_pts)
    elif min(d.shape) != nrow * ncol:
        # explicitly make contiguous
        data = d.T[:, :nrow * ncol].copy() if t else d[:, :nrow * ncol].copy()
    else:
        # no pruning and no extra channels: just orient and copy
        data = d.T.copy() if t else d.copy()

    if trig >= 0:
        if t:
            trigger_chan = d[trig, :]
        else:
            trigger_chan = d[:, trig]
    else:
        trigger_chan = None
    del d
    array_bunch = Bunch(data=data,
                        rowcol=(nrow, ncol),
                        Fs=Fs,
                        tx=tx,
                        segs=segs,
                        trigger_chan=trigger_chan)
    return array_bunch
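
A hedged usage sketch (the file name and trigger-channel index are hypothetical):

arr = load_arr('arr_test_01.mat', trig=64)   # hypothetical .mat file and trigger channel
data = arr.data                              # contiguous (and possibly pruned) samples
nrow, ncol = arr.rowcol
time_axis = arr.tx                           # time axis in seconds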
Example #7
def load_wireless(exp_path,
                  test,
                  electrode,
                  bandpass=(),
                  notches=(),
                  save=True,
                  snip_transient=True,
                  units='V'):

    data, trigs, Fs, cmap, bpass = load_cooked(exp_path, test, electrode)
    if units.lower() != 'v':
        convert_scale(data, 'v', units)

    dset = Bunch(data=data,
                 pos_edge=trigs,
                 chan_map=cmap,
                 Fs=Fs,
                 bandpass=bpass,
                 units=units)
    return dset
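
Usage follows the same pattern as the other loaders in this collection (the path, test name, and electrode below are hypothetical):

dset = load_wireless('/data/wireless_exp', 'test_003', 'ratv4_e1', units='uV')
lfp, Fs, chan_map = dset.data, dset.Fs, dset.chan_map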
Example #8
    def _get_screen(self, array, channels, chan_map, Fs):
        from ecogdata.expconfig import params
        mem_guideline = float(params.memory_limit)
        n_chan = len(array)
        word_size = array.dtype.itemsize
        n_pts = int(min(1000000, mem_guideline / n_chan / word_size))
        offset = int(self.screen_start * 60 * Fs)
        n_pts = min(array.shape[1] - offset, n_pts)
        data = np.empty((len(channels), n_pts), dtype=array.dtype)
        for n, c in enumerate(channels):
            data[n] = array[c, offset:offset + n_pts]

        data_bunch = Bunch(data=data,
                           chan_map=chan_map,
                           Fs=Fs,
                           units='au',
                           name='')
        mask = interactive_mask(data_bunch, use_db=False)
        screen_channels = [channels[i] for i in range(len(mask)) if mask[i]]
        screen_map = chan_map.subset(mask)
        return screen_channels, screen_map
Example #9
def cfg_to_bunch(cfg_file, section='', params_table=None):
    """Return session config info in Bunch (dictionary) form with interpolations
    from the master config settings. Perform full evaluation on parameters known
    here and leave subsequent evaluation downstream.
    """
    cp = new_SafeConfigParser()
    cp.read(cfg_file)
    sections = [section] if section else cp.sections()
    b = Bunch()
    if params_table is None:
        params_table = {}
    params_table.update(all_keys)
    for sec in sections:
        bsub = Bunch()
        opts = cp.options(sec)
        param_pairs = [(o, parse_param(o, cp.get(sec, o), params_table))
                       for o in opts]
        bsub.update(param_pairs)
        b[sec] = bsub
    b.sections = sections
    return b
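
A short usage sketch (the config file name is hypothetical; `sections` is attached to the returned Bunch as shown above):

conf = cfg_to_bunch('session_2020-01-01.txt')   # hypothetical session config
print(conf.sections)                            # all parsed section names
first = conf[conf.sections[0]]                  # per-section Bunch of parsed options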
Example #10
    def fill_stims(self, xml_file, ignore_skip=False):
        if self._filled:
            for key in self.event_names:
                del self.__dict__[key]
        # get tick events for good measure
        ticks = TickEvent.walk_events(xml_file)
        for attrib in TickEvent.attr_keys:
            val = ticks.pop(attrib)
            ticks['tick_' + attrib] = val
        data = self.event_type.walk_events(xml_file)
        data.update(ticks)
        keys = list(data.keys())
        if ignore_skip or not self.skip_blocks:
            keep_idx = slice(None)
        else:
            block_id = np.array(data['BlockID'])
            skipped_idx = [np.where(block_id == skip)[0]
                           for skip in self.skip_blocks]
            skipped_idx = np.concatenate(skipped_idx)
            keep_idx = np.setdiff1d(np.arange(len(block_id)), skipped_idx)
        for key in keys:
            arr = data.pop(key)
            data[key] = np.array(arr)[keep_idx]

        # do a second spin through to pick up the units conversions
        context = itertag_wrap(xml_file, 'Environment')
        for _, elem in context:
            # Element.getchildren() was removed in Python 3.9; index the element directly
            units = list(elem)[0]
            # pix size is uncertain .. should be dva
            pix_size = float(units.attrib['PixelSize'])
            # tick duration is in micro-secs
            tick_len = float(units.attrib['TickDuration'])
        # print 'got data:', data.keys()
        # print [len(val) for val in data.values()]
        self.stim_props = Bunch(pix_size=pix_size, tick_len=tick_len)
        self._fill_tables(**data)
        self._filled = True
Example #11
def load_openephys_ddc(exp_path,
                       test,
                       electrode,
                       drange,
                       trigger_idx,
                       rec_num='auto',
                       bandpass=(),
                       notches=(),
                       save=False,
                       snip_transient=True,
                       units='nA',
                       **extra):

    rawload = load_open_ephys_channels(exp_path, test, rec_num=rec_num)
    all_chans = rawload.chdata
    Fs = rawload.Fs

    # `rows` and `columns` (per-channel DDC pinout indices) are assumed to be
    # defined at module scope, alongside _dyn_range_lookup
    d_chans = len(rows)
    ch_data = all_chans[:d_chans]
    if np.iterable(trigger_idx):
        trigger = all_chans[int(trigger_idx[0])]
    else:
        trigger = all_chans[int(trigger_idx)]

    electrode_chans = rows >= 0
    chan_flat = mat_to_flat((8, 8),
                            rows[electrode_chans],
                            7 - columns[electrode_chans],
                            col_major=False)
    chan_map = ChannelMap(chan_flat, (8, 8), col_major=False, pitch=0.406)

    dr_lo, dr_hi = _dyn_range_lookup[drange]  # drange 0 3 or 7
    ch_data = convert_dyn_range(ch_data, (-2**15, 2**15), (dr_lo, dr_hi))

    data = shm.shared_copy(ch_data[electrode_chans])
    disconnected = ch_data[~electrode_chans]

    trigger -= trigger.mean()
    binary_trig = (trigger > 100).astype('i')
    if binary_trig.any():
        pos_edge = np.where(np.diff(binary_trig) > 0)[0] + 1
    else:
        pos_edge = ()

    # change units if not nA
    if 'a' in units.lower():
        # this puts it as picoamps
        data *= Fs
        data = convert_scale(data, 'pa', units)
    elif 'c' in units.lower():
        data = convert_scale(data, 'pc', units)

    if bandpass:  # how does this logic work?
        (b, a) = ft.butter_bp(lo=bandpass[0], hi=bandpass[1], Fs=Fs)
        filtfilt(data, b, a)

    if notches:
        ft.notch_all(data, Fs, lines=notches, inplace=True, filtfilt=True)

    if snip_transient:
        snip_len = min(10000, pos_edge[0]) if len(pos_edge) else 10000
        data = data[..., snip_len:].copy()
        if len(disconnected):
            disconnected = disconnected[..., snip_len:].copy()
        if len(pos_edge):
            trigger = trigger[..., snip_len:]
            pos_edge -= snip_len

    dset = Bunch()
    dset.data = data
    dset.pos_edge = pos_edge
    dset.trigs = trigger
    dset.ground_chans = disconnected
    dset.Fs = Fs
    dset.chan_map = chan_map
    dset.bandpass = bandpass
    dset.transient_snipped = snip_transient
    dset.units = units
    dset.notches = notches
    return dset
Example #12
def load_ddc(exp_path,
             test,
             electrode,
             drange,
             bandpass=(),
             notches=(),
             save=True,
             snip_transient=True,
             Fs=1000,
             units='nA',
             **extra):

    half = extra.get('half', False)
    avg = extra.get('avg', False)
    # returns data in coulombs, i.e. values are < 1e-12
    (data, disconnected, trigs, pos_edge, chan_map) = _load_cooked(exp_path,
                                                                   test,
                                                                   half=half,
                                                                   avg=avg)
    if half or avg:
        Fs = Fs / 2.0
    if 'a' in units.lower():
        data *= Fs
        data = convert_scale(data, 'a', units)
    elif 'c' in units.lower():
        data = convert_scale(data, 'c', units)

    if bandpass:
        (b, a) = ft.butter_bp(lo=bandpass[0], hi=bandpass[1], Fs=Fs)
        filtfilt(data, b, a)

    if notches:
        ft.notch_all(data, Fs, lines=notches, inplace=True, filtfilt=True)

    if snip_transient:
        snip_len = min(10000, pos_edge[0]) if len(pos_edge) else 10000
        data = data[..., snip_len:].copy()
        disconnected = disconnected[..., snip_len:].copy()
        if len(pos_edge):
            trigs = trigs[..., snip_len:]
            pos_edge -= snip_len

    dset = Bunch()
    dset.data = data
    dset.pos_edge = pos_edge
    dset.trigs = trigs
    dset.ground_chans = disconnected
    dset.Fs = Fs
    dset.chan_map = chan_map
    dset.bandpass = bandpass
    dset.transient_snipped = snip_transient
    dset.units = units
    dset.notches = notches
    return dset
Example #13
def save_bunch(f,
               path,
               b,
               mode='a',
               overwrite_paths=False,
               compress_arrays=0,
               skip_pickles=False):
    """
    Save a Bunch type to an HDF5 group in a new or existing table.

    Arrays, strings, lists, and various scalar types are saved as
    naturally supported array types. Sub-Bunches are written
    recursively in sub-paths. The remaining Bunch elements are
    pickled, preserving their object classification.

    MappedSource and BufferBase types are not saved, but can be reloaded
    if the corresponding FileLoader is included in the Bunch. This is presently
    limited to one FileLoader per HDF5 Group (or path level).

    Parameters
    ----------
    f: path or open tables file
    path: str
        Path in the HDF5 tree (e.g. /branch/node)
    b: Bunch
        Bunch to store at the path
    mode: str
        File access mode (caution: 'w' overwrites the entire file)
    overwrite_paths: bool
        If True, then an existing path in the HDF5 file may be overwritten
    compress_arrays: int
        Compression level (>0) for arrays. Arrays are stored uncompressed if 0.
    skip_pickles: bool
        Non-array types are "pickled" as strings in pytables, which may be unpickled by
        Python on loading. For maximum compatibility (e.g. Matlab), skip pickling.

    """

    # * create a new group
    # * save any array-like type natively (esp ndarrays)
    # * save everything else as the pickled ObjectAtom
    # * if there are any sub-bunches, then re-enter method with subgroup

    if not isinstance(f, tables.file.File):
        with closing(tables.open_file(f, mode)) as f:
            return save_bunch(f,
                              path,
                              b,
                              overwrite_paths=overwrite_paths,
                              compress_arrays=compress_arrays,
                              skip_pickles=skip_pickles)
    from ecogdata.devices.load.file2data import FileLoader
    # If we want to overwrite a node, check whether it exists.
    # If instead we want an exception when trying to overwrite, it will
    # be raised by f.create_group()
    if overwrite_paths:
        try:
            n = f.get_node(path)
            n._f_remove(recursive=True, force=True)
        except NoSuchNodeError:
            pass
    p, node = os.path.split(path)
    if node:
        f.create_group(p, node, createparents=True)

    sub_bunches = list()
    items = iter(b.items())
    pickle_bunch = Bunch()
    mapped_data = list()
    loader_saved = False

    # 1) create arrays for suitable types
    for key, val in items:
        if isinstance(val, FileLoader):
            loader_saved = True
        if isinstance(val, np.ndarray) and len(val.shape):
            atom = tables.Atom.from_dtype(val.dtype)
            if compress_arrays:
                filters = tables.Filters(complevel=compress_arrays,
                                         complib='zlib')
            else:
                filters = None
            ca = f.create_carray(path,
                                 key,
                                 atom=atom,
                                 shape=val.shape,
                                 filters=filters)
            ca[:] = val
        elif type(val) in _h5_seq_types:
            try:
                f.create_array(path, key, val)
            except (TypeError, ValueError):
                pickle_bunch[key] = val
        elif isinstance(val, _not_pickled):
            mapped_data.append(key)
        elif isinstance(val, Bunch):
            sub_bunches.append((key, val))
        else:
            pickle_bunch[key] = val

    # 2) pickle the remaining items (that are not bunches)
    if len(pickle_bunch):
        if skip_pickles:
            print('Warning: these keys are being skipped on path {}'.format(
                path))
            print(pickle_bunch)
        else:
            p_arr = f.create_vlarray(path,
                                     'b_pickle',
                                     atom=tables.ObjectAtom())
            p_arr.append(pickle_bunch)

    # 3) repeat these steps for any bunch elements that are also bunches
    for n, b in sub_bunches:
        #print 'saving', n, b
        subpath = path + '/' + n if path != '/' else path + n
        save_bunch(f,
                   subpath,
                   b,
                   compress_arrays=compress_arrays,
                   skip_pickles=skip_pickles)

    if mapped_data:
        print('Mapped data was skipped: ' + ', '.join(mapped_data))
        if loader_saved:
            print(
                'A data loader object was saved. Use "attempt_reload=True" with load_bunch to recover data.'
            )
    return
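
A round-trip sketch pairing `save_bunch` with the `load_bunch` calls seen in the other examples (the file name is hypothetical; the `Bunch` import path is an assumption):

import numpy as np
from ecogdata.util import Bunch   # assumed location of Bunch

b = Bunch(data=np.random.randn(4, 1000), Fs=1000.0, note='scratch session')
b.sub = Bunch(gain=2.5)                      # sub-Bunches become sub-groups
save_bunch('scratch.h5', '/session1', b, mode='a', compress_arrays=5)
b2 = load_bunch('scratch.h5', '/session1')   # arrays are restored natively; other entries are recovered too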
Example #14
def load_open_ephys_channels(exp_path,
                             test,
                             rec_num='auto',
                             shared_array=False,
                             downsamp=1,
                             target_Fs=-1,
                             lowpass_ord=12,
                             page_size=8,
                             save_downsamp=True,
                             use_stored=True,
                             store_path='',
                             quantized=False):

    # first off, check if there is a stored file at target_Fs (if valid)
    if use_stored and target_Fs > 0:
        # Look for a previously downsampled data stash
        fname_part = '*{0}*_Fs{1}.h5'.format(test, int(target_Fs))
        # try store_path (if given) and also exp_path
        for p_ in (store_path, exp_path):
            fname = glob(osp.join(p_, fname_part))
            if len(fname) and osp.exists(fname[0]):
                print('Loading from', fname[0])
                channel_data = load_bunch(fname[0], '/')
                return channel_data

    rec_path, rec_num = prepare_paths(exp_path, test, rec_num)
    trueFs = get_robust_samplingrate(rec_path)
    if downsamp == 1 and target_Fs > 0:
        if trueFs is None:
            # can't compute a downsample factor without a reliable sampling rate
            print('Sampling frequency not robustly determined, '
                  'downsample not calculated for {0:.1f} Hz'.format(target_Fs))
            raise ValueError
        else:
            # find the correct (integer) downsample rate
            # to get (approx) target Fs
            # target_fs * downsamp <= Fs
            # downsamp <= Fs / target_fs
            downsamp = int(trueFs // target_Fs)
            print('downsample rate:', downsamp)

    if downsamp > 1 and quantized:
        print('Cannot return quantized data when downsampling')
        quantized = False
    downsamp = int(downsamp)

    all_files = list()
    for pre in rec_num:
        all_files.extend(glob(osp.join(rec_path, pre + '*.continuous')))
    if not len(all_files):
        raise IOError('No files found')
    c_nums = list()
    chan_files = list()
    aux_files = list()
    aux_nums = list()
    adc_files = list()
    adc_nums = list()
    for f in all_files:
        f_part = osp.splitext(osp.split(f)[1])[0]
        # File names can be: Proc#_{ADC/CH/AUX}[_N].continuous
        # (the trailing _N part is not always present; discard it for now)
        f_parts = f_part.split('_')
        if len(f_parts[-1]) == 1 and f_parts[-1] in '0123456789':
            f_parts = f_parts[:-1]
        ch = f_parts[-1]  # last file part is CHx or AUXx
        if ch[0:2] == 'CH':
            chan_files.append(f)
            c_nums.append(int(ch[2:]))
        elif ch[0:3] == 'AUX':  # separate chan and AUX files
            aux_files.append(f)
            aux_nums.append(int(ch[3:]))
        elif ch[0:3] == 'ADC':
            adc_files.append(f)
            adc_nums.append(int(ch[3:]))

    if downsamp > 1:
        (b_lp, a_lp) = cheby2_bp(60, hi=1.0 / downsamp, Fs=2, ord=lowpass_ord)

    def _load_array_block(files, shared_array=False, antialias=True):
        Fs = 1
        dtype = 'h' if quantized else 'd'

        # start on 1st index of 0th block
        n = 1
        b_cnt = 0
        b_idx = 1

        ch_record = OE.loadContinuous(files[0], dtype=np.int16, verbose=False)
        d_len = ch_record['data'].shape[-1]
        sub_len = d_len // downsamp
        if sub_len * downsamp < d_len:
            sub_len += 1
        proc_block = shm.shared_ndarray((page_size, d_len), typecode=dtype)
        proc_block[0] = ch_record['data'].astype(dtype)
        if shared_array:
            saved_array = shm.shared_ndarray((len(files), sub_len),
                                             typecode=dtype)
        else:
            saved_array = np.zeros((len(files), sub_len), dtype=dtype)

        for f in files[1:]:
            ch_record = OE.loadContinuous(f, dtype=np.int16,
                                          verbose=False)  # load data
            Fs = float(ch_record['header']['sampleRate'])
            proc_block[b_idx] = ch_record['data'].astype(dtype)
            b_idx += 1
            n += 1
            if (b_idx == page_size) or (n == len(files)):
                # do dynamic range conversion and downsampling
                # on a block of data
                if not quantized:
                    proc_block *= ch_record['header']['bitVolts']
                if downsamp > 1 and antialias:
                    filtfilt(proc_block, b_lp, a_lp)
                sl = slice(b_cnt * page_size, n)
                saved_array[sl] = proc_block[:b_idx, ::downsamp]
                # update / reset block counters
                b_idx = 0
                b_cnt += 1

        del proc_block
        while gc.collect():
            pass
        return saved_array, Fs, ch_record['header']

    # sort CH, AUX, and ADC by the channel number
    sorted_chans = np.argsort(c_nums)
    # sort the file list based on sorted_chans
    chan_files = [chan_files[n] for n in sorted_chans]
    chdata, Fs, header = _load_array_block(chan_files,
                                           shared_array=shared_array)

    aux_data = list()
    if len(aux_files) > 0:
        sorted_aux = np.argsort(aux_nums)
        aux_files = [aux_files[n] for n in sorted_aux]
        aux_data, _, _ = _load_array_block(aux_files, antialias=False)

    adc_data = list()
    if len(adc_files) > 0:
        sorted_adc = np.argsort(adc_nums)
        adc_files = [adc_files[n] for n in sorted_adc]
        adc_data, _, _ = _load_array_block(adc_files, antialias=False)

    if not trueFs:
        print('settings.xml not found, relying on sampling rate from '
              'recording header files')
        trueFs = Fs
    if downsamp > 1:
        trueFs /= downsamp
    dset = Bunch(chdata=chdata,
                 aux=aux_data,
                 adc=adc_data,
                 Fs=trueFs,
                 header=header)

    if save_downsamp and downsamp > 1:
        fname = '{0}_Fs{1}.h5'.format(osp.split(rec_path)[-1], int(dset.Fs))
        if not len(store_path):
            store_path = exp_path
        mkdir_p(store_path)
        fname = osp.join(store_path, fname)
        print('saving', fname)
        save_bunch(fname, '/', dset, mode='w')

    return dset
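
A hedged usage sketch (the paths and recording name are hypothetical):

phys = load_open_ephys_channels('/data/oephys', 'rec_2020-01-01_1', target_Fs=2000.0)
ecog, Fs = phys.chdata, phys.Fs        # amplifier channels, downsampled to ~2 kHz
aux, adc = phys.aux, phys.adc          # AUX / ADC channels, decimated without antialiasing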
Example #15
    def create_dataset(self):
        """
        Maps or loads raw data at the specified sampling rate and with the specified filtering applied.
        The sequence of steps follows a general logic with loading and transformation methods that can be delegated
        to subtypes.

        The final dataset can be either memory-mapped or not and downsampled or not. To avoid unnecessary
        file/memory copies, datasets are created along this path:

        This object has a "data_file" to begin dataset creation with. If downsampling, then a new file must be created
        in the case of mapping and/or saving the results. That new file source is created in
        `create_downsample_file` (candidate for overloading), and supersedes the source "data_file".

        Source file channels are organized into electrode channels and possible grounded input and reference channels.

        The prevailing "data_file" (primary or downsampled) is mapped or loaded in `map_raw_data`. If mapped,
        then MappedSource types are returned, else PlainArraySource types are returned. If downsampling is
        still pending (because the created dataset is neither mapped nor is the downsample saved), then memory
        loading is implied. This is handled by making the downsample conversion directly to memory within
        the `map_raw_data` method (another candidate for overloading).

        Check if read/write access is required for filtering, or because of self.ensure_writeable. If the data
        sources at this point are not writeable (e.g. mapped primary sources), then mirror to writeable files. If the
        dataset is not to be mapped, then promote to memory if necessary.

        Do filtering if necessary.

        Do timing extraction if necessary via `find_trigger_signals` (system specific).

        Returns
        -------
        dataset: Bunch
            Bunch containing ".data" (a DataSource), ".chan_map" (a ChannelMap), and many other metadata attributes.

        """

        channel_map, electrode_chans, ground_chans, ref_chans = self.make_channel_map()

        data_file = self.data_file
        file_is_temp = False
        # "new_downsamp_file" is not None if there was no pre-computed downsample. There are three possibilities:
        # 1. downsample to memory only (not mapped and not saving)
        # 2. downsample to a temp file (mapped and not saving)
        # 3. downsample to a named file (saving -- may eventually be mapped or not)
        needs_downsamp = self.new_downsamp_file is not None
        needs_file = needs_downsamp and (self.save_downsamp or self.mapped)
        if needs_file:
            # 1. Do downsample conversion in a subroutine
            # 2. Save to a named file if save_downsamp is True
            # 3. Determine if the new source file is writeable (i.e. a temp file) or not
            downsamp_file = (self.new_downsamp_file
                             if self.save_downsamp else '')
            print('Downsampling to {} Hz from file {} to file {}'.format(
                self.resample_rate, self.data_file, downsamp_file))
            downsamp_path = os.path.split(downsamp_file)[0]
            # if a downsample file needs to be saved, make sure the path exists
            if len(downsamp_path) and not os.path.exists(downsamp_path):
                mkdir_p(downsamp_path)
            data_file = self.create_downsample_file(self.data_file,
                                                    self.resample_rate,
                                                    downsamp_file)
            file_is_temp = not self.save_downsamp
            self.units_scale = convert_scale(1, 'uv', self.units)
            downsample_ratio = 1
        elif needs_downsamp:
            # The "else" case now is that the master electrode source (and ref and ground channels)
            # needs downsampling to PlainArraySources
            downsample_ratio = int(self.raw_sample_rate() / self.resample_rate)
        else:
            downsample_ratio = 1

        open_mode = 'r+' if file_is_temp else 'r'

        Fs = self.raw_sample_rate()
        if self.resample_rate:
            Fs = self.resample_rate

        # Find the full set of expected electrode channels within the "amplifier_data" array
        # electrode_channels = [n for n in range(n_amplifier_channels) if n not in ground_chans + ref_chans]
        if self.load_channels:
            # If load_channels is specified, need to modify the electrode channel list and find the channel map subset
            sub_channels = list()
            sub_indices = list()
            for i, n in enumerate(electrode_chans):
                if n in self.load_channels:
                    sub_channels.append(n)
                    sub_indices.append(i)
            electrode_chans = sub_channels
            channel_map = channel_map.subset(sub_indices)
            ground_chans = [n for n in ground_chans if n in self.load_channels]
            ref_chans = [n for n in ref_chans if n in self.load_channels]

        # Setting up sources should be out-sourced to a method subject to overloading. For example, open-ephys data
        # are stored file-per-channel. In the case of loading a full-sampling-rate recording to memory, the original
        # logic would require packing to HDF5 before loading (inefficient).
        datasource, ground_chans, ref_chans = self.map_raw_data(
            data_file, open_mode, electrode_chans, ground_chans, ref_chans,
            downsample_ratio)

        # Promote to a writeable and possibly RAM-loaded array here if either the final source should be loaded,
        # or if the mapped source is not writeable.
        needs_load = isinstance(datasource, MappedSource) and not self.mapped
        filtering = bool(self.bandpass) or bool(self.notches)
        needs_writeable = (self.ensure_writeable
                           or filtering) and not datasource.writeable
        if needs_load or needs_writeable:
            # Need to make writeable copies of these data sources. If the final source is to be loaded, then mirror
            # to memory here. Copy everything to memory if not mapped, otherwise copy only aligned arrays.
            if not self.mapped or not filtering:
                # Load data if not mapped. If mapped as writeable but not filtering, then copy to new file
                copy_mode = 'all'
            else:
                copy_mode = 'aligned'
            source_type = 'writeable mapped' if self.mapped else 'RAM'
            print('Creating {} sources with copy mode: {}'.format(
                source_type, copy_mode))
            datasource_w = datasource.mirror(mapped=self.mapped,
                                             writeable=True,
                                             copy=copy_mode)
            if ground_chans:
                ground_chans_w = ground_chans.mirror(mapped=self.mapped,
                                                     writeable=True,
                                                     copy=copy_mode)
            if ref_chans:
                ref_chans_w = ref_chans.mirror(mapped=self.mapped,
                                               writeable=True,
                                               copy=copy_mode)
            if not self.mapped or not filtering:
                # swap handles of these objects
                datasource = datasource_w
                datasource_w = None
                if ground_chans:
                    ground_chans = ground_chans_w
                    ground_chans_w = None
                if ref_chans:
                    ref_chans = ref_chans_w
                    ref_chans_w = None
        elif filtering:
            # in this case the datasource was already writeable/loaded
            datasource_w = ground_chans_w = ref_chans_w = None

        # For the filter blocks...
        # If mapped, then datasource and datasource_w will be identical (after filter_array call)
        # If loaded, then datasource_w is None and datasource is filtered in-place
        if self.bandpass:
            # TODO: should filter in two stages for stability
            # filter inplace if the "writeable" source is set to None
            filter_kwargs = dict(
                ftype='butterworth',
                inplace=datasource_w is None,
                design_kwargs=dict(lo=self.bandpass[0],
                                   hi=self.bandpass[1],
                                   Fs=Fs),
                filt_kwargs=dict(filtfilt=not self.causal_filtering))
            if self.mapped:
                # make "verbose" filtering with progress bar if we're filtering a mapped source
                filter_kwargs['filt_kwargs']['verbose'] = True
            print('Bandpass filtering')
            datasource = datasource.filter_array(out=datasource_w,
                                                 **filter_kwargs)
            if ground_chans:
                ground_chans = ground_chans.filter_array(out=ground_chans_w,
                                                         **filter_kwargs)
            if ref_chans:
                ref_chans = ref_chans.filter_array(out=ref_chans_w,
                                                   **filter_kwargs)
        if self.notches:
            print('Notch filtering')
            notch_kwargs = dict(
                inplace=datasource_w is None,
                lines=self.notches,
                filt_kwargs=dict(filtfilt=not self.causal_filtering))
            if self.mapped:
                notch_kwargs['filt_kwargs']['verbose'] = True
            datasource = datasource.notch_filter(Fs,
                                                 out=datasource_w,
                                                 **notch_kwargs)
            if ground_chans:
                ground_chans = ground_chans.notch_filter(Fs,
                                                         out=ground_chans_w,
                                                         **notch_kwargs)
            if ref_chans:
                ref_chans = ref_chans.notch_filter(Fs,
                                                   out=ref_chans_w,
                                                   **notch_kwargs)

        trigger_signal, pos_edge = self.find_trigger_signals(data_file)
        if not needs_file and downsample_ratio > 1:
            trigger_signal = trigger_signal[..., ::downsample_ratio]
            pos_edge = np.round(pos_edge.astype('d') /
                                downsample_ratio).astype('i')

        # Viventi lab convention: stim signal would be on the next available ADC channel... skip explicitly loading
        # this, because the "board_adc_data" array is conjoined with the main datasource

        dataset = Bunch()
        dataset.data = datasource
        for arr in datasource.aligned_arrays:
            dataset[arr] = getattr(datasource, arr)
        dataset.chan_map = channel_map
        dataset.Fs = Fs
        dataset.pos_edge = pos_edge
        dataset.trig_chan = trigger_signal
        dataset.bandpass = self.bandpass
        dataset.notches = self.notches
        dataset.units = self.units
        dataset.transient_snipped = False
        dataset.ground_chans = ground_chans
        dataset.ref_chans = ref_chans
        dataset.loader = self
        return dataset
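
A heavily hedged usage sketch: `create_dataset` is meant to be called on a concrete `FileLoader` subclass. The class name and constructor arguments below are hypothetical placeholders (they mirror attributes used in the method, not a documented signature):

# `SomeRecordingLoader` stands in for any concrete FileLoader subtype
loader = SomeRecordingLoader(data_file='/data/session/raw.h5',   # hypothetical arguments
                             mapped=True,
                             resample_rate=1000,
                             bandpass=(2, 100),
                             notches=(60,),
                             units='uV')
dataset = loader.create_dataset()
lfp = dataset.data          # a DataSource (mapped or in-memory)
chan_map = dataset.chan_map
Fs = dataset.Fs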
Example #16
def load_afe(exp_pth,
             test,
             electrode,
             n_data,
             range_code,
             cycle_rate,
             units='nA',
             bandpass=(),
             save=True,
             notches=(),
             snip_transient=True,
             **extra):

    h5 = tables.open_file(os.path.join(exp_pth, test + '.h5'))

    data = h5.root.data[:]
    Fs = h5.root.Fs[0, 0]

    if data.shape[1] > n_data:
        trig_chans = data[:, n_data:]
        trig = np.any(trig_chans > 1, axis=1).astype('i')
        pos_edge = np.where(np.diff(trig) > 0)[0] + 1
    else:
        trig = None
        pos_edge = ()

    data_chans = data[:, :n_data].T.copy(order='C')

    # convert dynamic range to charge or current
    if 'v' not in units.lower():
        pico_coulombs = range_lookup[range_code]
        convert_dyn_range(data_chans, (-1.4, 1.4),
                          pico_coulombs,
                          out=data_chans)
        if 'a' in units.lower():
            # To convert to amps, need to divide coulombs by the
            # integration period. This is found approximately by
            # finding out how many cycles in a scan period were spent
            # integrating. A scan period is now hard coded to be 500
            # cycles. The cycling rate is given as an argument.
            # The integration period for channel i should be:
            # 500 - 2*(n_data - i)
            # That is, the 1st channel is clocked out over two cycles
            # immediately after the integration period. Meanwhile other
            # channels still acquire until they are clocked out.
            n_cycles = 500
            #i_cycles = n_cycles - 2*(n_data - np.arange(n_data))
            i_cycles = n_cycles - 2 * n_data
            i_period = i_cycles / cycle_rate
            data_chans /= i_period  #[:,None]
            convert_scale(data_chans, 'pa', units)
    elif units.lower() != 'v':
        # scale the electrode channels (not the discarded raw array) from volts to the requested unit
        convert_scale(data_chans, 'v', units)

    # only use this one electrode (for now)
    chan_map, disconnected = epins.get_electrode_map('psv_61_afe')[:2]
    connected = np.setdiff1d(np.arange(n_data), disconnected)
    disconnected = disconnected[disconnected < n_data]

    chan_map = chan_map.subset(list(range(len(connected))))

    data = shm.shared_ndarray((len(connected), data_chans.shape[-1]))
    data[:, :] = data_chans[connected]
    ground_chans = data_chans[disconnected].copy()
    del data_chans

    if bandpass:
        # do a little extra to kill DC
        data -= data.mean(axis=1)[:, None]
        (b, a) = ft.butter_bp(lo=bandpass[0], hi=bandpass[1], Fs=Fs)
        filtfilt(data, b, a)
    if notches:
        for freq in notches:
            (b, a) = ft.notch(freq, Fs=Fs, ftype='cheby2')
            filtfilt(data, b, a)
    else:
        data -= data.mean(axis=1)[:, None]

    ## detrend_window = int(round(0.750*Fs))
    ## ft.bdetrend(data, bsize=detrend_window, type='linear', axis=-1)

    if snip_transient:
        snip_len = min(10000, pos_edge[0]) if len(pos_edge) else 10000
        data = data[..., snip_len:].copy()
        ground_chans = ground_chans[..., snip_len:].copy()
        if len(pos_edge):
            trig = trig[..., snip_len:]
            pos_edge -= snip_len

    dset = Bunch()

    dset.data = data
    dset.ground_chans = ground_chans
    dset.chan_map = chan_map
    dset.Fs = Fs
    dset.pos_edge = pos_edge
    dset.bandpass = bandpass
    dset.trig = trig
    dset.transient_snipped = snip_transient
    dset.units = units
    dset.notches = notches
    return dset
Example #17
def _convert(rec,
             from_unit,
             to_unit,
             inverted=False,
             tfs=(),
             Rct=None,
             Cdl=None,
             Rs=None,
             prewarp=False,
             **sm_kwargs):
    from_type = from_unit[-1]
    to_type = to_unit[-1]

    def _override_tf(b, a):
        b, a = signal.normalize(b, a)
        Rs_, Rct_, Cdl_ = tf_to_circuit(b, a)
        b, a = circuit_to_tf((Rs if Rs else Rs_), (Rct if Rct else Rct_),
                             (Cdl if Cdl else Cdl_))
        return b, a

    if not len(tfs):
        from ecogdata.expconfig import params
        session = rec.name.split('.')[0]
        # backwards compatibility
        session = session.split('/')[-1]
        if session in _transform_lookup:
            # XXX: ideally should find the expected unit scale of the
            # transforms in the transform Bunch -- for now assume pico-scale
            to_scale = convert_scale(1, from_unit, 'p' + from_type)
            from_scale = convert_scale(1, 'p' + to_type, to_unit)
            stash_path = pt.sep.join(
                [params.stash_path, 'devices', 'impedance_transfer'])
            transforms = pt.join(stash_path, _transform_lookup[session])
            tfs = load_bunch(transforms, '/')
        else:
            # do an analog integrator --
            # will be converted to 1 + z / (1 - z) later
            # do this on pico scale also (?)
            to_scale = convert_scale(1, from_unit, 'p' + from_type)
            from_scale = convert_scale(1, 'p' + to_type, to_unit)
            tfs = Bunch(aa=np.array([1, 0]), bb=np.array([0, 1]))
    else:
        to_scale = convert_scale(1, from_unit, 'p' + from_type)
        from_scale = convert_scale(1, 'p' + to_type, to_unit)

    if tfs.aa.ndim > 1:
        from ecogdata.filt.time import bfilter
        conv = rec.deepcopy()
        cmap = conv.chan_map
        bb, aa = smooth_transfer_functions(tfs.bb.T, tfs.aa.T, **sm_kwargs)
        for n, ij in enumerate(zip(*cmap.to_mat())):
            b = bb[ij]
            a = aa[ij]
            b, a = _override_tf(b, a)
            if prewarp:
                T = conv.Fs**-1
                z, p, k = signal.tf2zpk(b, a)
                z = 2 / T * np.tan(z * T / 2)
                p = 2 / T * np.tan(p * T / 2)
                b, a = signal.zpk2tf(z, p, k)
            if inverted:
                zb, za = signal.bilinear(a, b, fs=conv.Fs)
            else:
                zb, za = signal.bilinear(b, a, fs=conv.Fs)
            bfilter(zb, za, conv.data[n], bsize=10000)
    else:
        from ecogdata.parallel.split_methods import bfilter
        # avoid needless copy of data array
        rec_data = rec.pop('data')
        conv = rec.deepcopy()
        rec.data = rec_data
        conv_data = shm.shared_ndarray(rec_data.shape, rec_data.dtype.char)
        conv_data[:] = rec_data
        b = tfs.bb
        a = tfs.aa
        if prewarp:
            T = conv.Fs**-1
            z, p, k = signal.tf2zpk(b, a)
            z = 2 / T * np.tan(z * T / 2)
            p = 2 / T * np.tan(p * T / 2)
            b, a = signal.zpk2tf(z, p, k)
        if inverted:
            zb, za = signal.bilinear(a, b, fs=conv.Fs)
        else:
            zb, za = signal.bilinear(b, a, fs=conv.Fs)
        bfilter(zb, za, conv_data, bsize=10000)
        conv.data = conv_data

    conv.data *= (from_scale * to_scale)
    conv.units = to_unit
    return conv
Example #18
def traverse_table(f,
                   path='/',
                   load=True,
                   scan=False,
                   shared_paths=(),
                   skip_stale_pickles=True,
                   attempt_reload=False):
    # Walk nodes and stuff arrays into the bunch.
    # If we encounter a group, then loop back into this method
    from ecogdata.devices.load.file2data import FileLoader
    if not isinstance(f, tables.file.File):
        if load or scan:
            # If scan is True, load should be forced False here
            load = not scan
            with closing(tables.open_file(f, mode='r')) as f:
                return traverse_table(f,
                                      path=path,
                                      load=load,
                                      scan=scan,
                                      shared_paths=shared_paths,
                                      skip_stale_pickles=skip_stale_pickles,
                                      attempt_reload=attempt_reload)
        else:
            f = tables.open_file(f, mode='r')
            try:
                return traverse_table(f,
                                      path=path,
                                      load=load,
                                      scan=scan,
                                      shared_paths=shared_paths,
                                      skip_stale_pickles=skip_stale_pickles,
                                      attempt_reload=attempt_reload)
            except:
                f.close()
                raise
    if load or scan:
        gbunch = Bunch()
    else:
        gbunch = HDF5Bunch(f)
    (p, g) = os.path.split(path)
    if g == '':
        g = p
    nlist = f.list_nodes(path)
    #for n in f.walk_nodes(where=path):
    for n in nlist:
        if isinstance(n, tables.Array):
            if load:
                if n.dtype.char == 'O':
                    arr = 'Not loaded: ' + n.name
                elif '/'.join([path, n.name]) in shared_paths:
                    arr = shm.shared_ndarray(n.shape)
                    arr[:] = n.read()
                else:
                    arr = n.read()
                if isinstance(arr, np.ndarray) and n.shape:
                    if arr.shape == (1, 1):
                        arr = arr[0, 0]
                        if arr == 0:
                            arr = None
                    else:
                        arr = arr.squeeze()
            else:
                arr = n
            gbunch[n.name] = arr
        elif isinstance(n, tables.VLArray):
            if load:
                try:
                    obj = n.read()[0]
                except (ModuleNotFoundError, PickleError, PicklingError):
                    if not skip_stale_pickles:
                        raise
                    gbunch[n.name] = 'unloadable pickle'
                    continue
                # if it's a generic Bunch Pickle, then update the bunch
                if n.name == 'b_pickle':
                    gbunch.update(obj)
                else:
                    gbunch[n.name] = obj
            else:
                # ignore the empty pickle
                if n.name == 'b_pickle' and n.size_in_memory > 32:
                    gbunch[n.name] = 'unloaded pickle'
        elif isinstance(n, tables.Group):
            gname = n._v_name
            # walk_nodes() includes the current group:
            # don't try to descend into this node!
            if gname == g:
                continue
            if gname == '#refs#':
                continue
            subbunch = traverse_table(f,
                                      path='/'.join([path, gname]),
                                      load=load,
                                      scan=scan,
                                      shared_paths=shared_paths,
                                      skip_stale_pickles=skip_stale_pickles,
                                      attempt_reload=attempt_reload)
            gbunch[gname] = subbunch

        else:
            gbunch[n.name] = 'Not Loaded!'

    this_node = f.get_node(path)
    for attr in this_node._v_attrs._f_list():
        gbunch[attr] = this_node._v_attrs[attr]

    loaders = [v for v in gbunch.values() if isinstance(v, FileLoader)]
    if attempt_reload and loaders:
        for loader in loaders:
            print('Attempting load from {}'.format(loader.primary_data_file))
            dataset = loader.create_dataset()
            new_keys = set(dataset.keys()) - set(gbunch.keys())
            for k in new_keys:
                gbunch[k] = dataset.pop(k)
    return gbunch
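
The `load`/`scan` flags control how much comes back into memory; a minimal sketch (the file name is hypothetical):

# fully load arrays and pickles into a plain Bunch
b = traverse_table('scratch.h5', path='/', load=True)

# keep node references on disk instead (an HDF5Bunch holding the open file)
b_lazy = traverse_table('scratch.h5', path='/', load=False, scan=False)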