Example #1
def load_ddc(exp_path,
             test,
             electrode,
             drange,
             bandpass=(),
             notches=(),
             save=True,
             snip_transient=True,
             Fs=1000,
             units='nA',
             **extra):

    half = extra.get('half', False)
    avg = extra.get('avg', False)
    # returns data in coulombs, i.e. values are < 1e-12
    (data, disconnected, trigs, pos_edge, chan_map) = _load_cooked(exp_path,
                                                                   test,
                                                                   half=half,
                                                                   avg=avg)
    if half or avg:
        Fs = Fs / 2.0
    if 'a' in units.lower():
        # charge per sample (C) times samples/sec (Fs) gives current (A)
        data *= Fs
        data = convert_scale(data, 'a', units)
    elif 'c' in units.lower():
        data = convert_scale(data, 'c', units)

    if bandpass:
        (b, a) = ft.butter_bp(lo=bandpass[0], hi=bandpass[1], Fs=Fs)
        filtfilt(data, b, a)

    if notches:
        ft.notch_all(data, Fs, lines=notches, inplace=True, filtfilt=True)

    if snip_transient:
        snip_len = min(10000, pos_edge[0]) if len(pos_edge) else 10000
        data = data[..., snip_len:].copy()
        disconnected = disconnected[..., snip_len:].copy()
        if len(pos_edge):
            trigs = trigs[..., snip_len:]
            pos_edge -= snip_len

    dset = Bunch()
    dset.data = data
    dset.pos_edge = pos_edge
    dset.trigs = trigs
    dset.ground_chans = disconnected
    dset.Fs = Fs
    dset.chan_map = chan_map
    dset.bandpass = bandpass
    dset.transient_snipped = snip_transient
    dset.units = units
    dset.notches = notches
    return dset
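A minimal sketch of the unit arithmetic used above, independent of the loader (the numbers, sampling rate, and nA target are invented; the last line shows the power-of-ten factor that `convert_scale(data, 'a', 'na')` is assumed to apply):

import numpy as np

# Hypothetical per-sample charge readings in coulombs, sampled at 1 kHz.
Fs = 1000.0                       # samples / second
q = np.array([2e-12, -1e-12])     # coulombs accumulated each sample period
i_amps = q * Fs                   # C/s = A, the `data *= Fs` step
i_nanoamps = i_amps * 1e9         # the A -> nA rescaling convert_scale performs
print(i_nanoamps)                 # [ 2. -1.]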
Example #2
def load_openephys_ddc(exp_path,
                       test,
                       electrode,
                       drange,
                       trigger_idx,
                       rec_num='auto',
                       bandpass=(),
                       notches=(),
                       save=False,
                       snip_transient=True,
                       units='nA',
                       **extra):

    rawload = load_open_ephys_channels(exp_path, test, rec_num=rec_num)
    all_chans = rawload.chdata
    Fs = rawload.Fs

    # NOTE: `rows` / `columns` (per-channel electrode grid coordinates implied by
    # the `electrode` argument) are assumed to be defined by a pinout lookup that
    # is not shown in this snippet.
    d_chans = len(rows)
    ch_data = all_chans[:d_chans]
    if np.iterable(trigger_idx):
        trigger = all_chans[int(trigger_idx[0])]
    else:
        trigger = all_chans[int(trigger_idx)]

    electrode_chans = rows >= 0
    chan_flat = mat_to_flat((8, 8),
                            rows[electrode_chans],
                            7 - columns[electrode_chans],
                            col_major=False)
    chan_map = ChannelMap(chan_flat, (8, 8), col_major=False, pitch=0.406)

    dr_lo, dr_hi = _dyn_range_lookup[drange]  # drange code: 0, 3, or 7
    ch_data = convert_dyn_range(ch_data, (-2**15, 2**15), (dr_lo, dr_hi))

    data = shm.shared_copy(ch_data[electrode_chans])
    disconnected = ch_data[~electrode_chans]

    trigger -= trigger.mean()
    binary_trig = (trigger > 100).astype('i')
    if binary_trig.any():
        pos_edge = np.where(np.diff(binary_trig) > 0)[0] + 1
    else:
        pos_edge = ()

    # change units if not nA
    if 'a' in units.lower():
        # data is in picocoulombs; charge per sample times Fs gives picoamps
        data *= Fs
        data = convert_scale(data, 'pa', units)
    elif 'c' in units.lower():
        data = convert_scale(data, 'pc', units)

    if bandpass:
        # design Butterworth band-pass coefficients and apply them zero-phase, in place
        (b, a) = ft.butter_bp(lo=bandpass[0], hi=bandpass[1], Fs=Fs)
        filtfilt(data, b, a)

    if notches:
        ft.notch_all(data, Fs, lines=notches, inplace=True, filtfilt=True)

    if snip_transient:
        snip_len = min(10000, pos_edge[0]) if len(pos_edge) else 10000
        data = data[..., snip_len:].copy()
        if len(disconnected):
            disconnected = disconnected[..., snip_len:].copy()
        if len(pos_edge):
            trigger = trigger[..., snip_len:]
            pos_edge -= snip_len

    dset = Bunch()
    dset.data = data
    dset.pos_edge = pos_edge
    dset.trigs = trigger
    dset.ground_chans = disconnected
    dset.Fs = Fs
    dset.chan_map = chan_map
    dset.bandpass = bandpass
    dset.transient_snipped = snip_transient
    dset.units = units
    dset.notches = notches
    return dset
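The rising-edge detection used for `pos_edge` above can be illustrated on its own with NumPy (the toy trigger trace is invented; the 100-count threshold mirrors the code):

import numpy as np

# Toy trigger trace: baseline near zero with two high pulses.
trigger = np.array([0, 0, 500, 500, 0, 0, 500, 0], dtype='d')
trigger -= trigger.mean()
binary_trig = (trigger > 100).astype('i')
# Rising edges: samples where the binarized trace steps from 0 to 1.
pos_edge = np.where(np.diff(binary_trig) > 0)[0] + 1
print(pos_edge)   # [2 6]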
Example #3
    def create_dataset(self):
        """
        Maps or loads raw data at the specified sampling rate and with the specified filtering applied.
        The sequence of steps follows a general logic with loading and transformation methods that can be delegated
        to subtypes.

        The final dataset may be memory-mapped or loaded to RAM, and downsampled or left at the raw rate.
        To avoid unnecessary file/memory copies, datasets are created along this path:

        This object has a "data_file" to begin dataset creation with. If downsampling, then a new file must be created
        in the case of mapping and/or saving the results. That new file source is created in
        `create_downsample_file` (candidate for overloading), and supercedes the source "data_file".

        Source file channels are organized into electrode channels plus any grounded-input and reference channels.

        The prevailing "data_file" (primary or downsampled) is mapped or loaded in `map_raw_data`. If mapped,
        then MappedSource types are returned, else PlainArraySource types are retured. If downsampling is
        still pending (because the created dataset is neither mapped nor is the downsample saved), then memory
        loading is implied. This is handled by making the downsample conversion directly to memory within
        the `map_raw_data` method (another candidate for overloading).

        Check if read/write access is required for filtering, or because of self.ensure_writeable. If the data
        sources at this point are not writeable (e.g. mapped primary sources), then mirror to writeable files. If the
        dataset is not to be mapped, then promote to memory if necessary.

        Do filtering if necessary.

        Do timing extraction if necessary via `find_trigger_signals` (system specific).

        Returns
        -------
        dataset: Bunch
            Bunch containing ".data" (a DataSource), ".chan_map" (a ChannelMap), and many other metadata attributes.

        """

        channel_map, electrode_chans, ground_chans, ref_chans = self.make_channel_map(
        )

        data_file = self.data_file
        file_is_temp = False
        # "new_downsamp_file" is not None if there was no pre-computed downsample. There are three possibilities:
        # 1. downsample to memory only (not mapped and not saving)
        # 2. downsample to a temp file (mapped and not saving)
        # 3. downsample to a named file (saving -- may eventually be mapped or not)
        needs_downsamp = self.new_downsamp_file is not None
        needs_file = needs_downsamp and (self.save_downsamp or self.mapped)
        if needs_file:
            # 1. Do downsample conversion in a subroutine
            # 2. Save to a named file if save_downsamp is True
            # 3. Determine if the new source file is writeable (i.e. a temp file) or not
            downsamp_file = (self.new_downsamp_file
                             if self.save_downsamp else '')
            print('Downsampling to {} Hz from file {} to file {}'.format(
                self.resample_rate, self.data_file, downsamp_file))
            downsamp_path = os.path.split(downsamp_file)[0]
            # if a downsample file needs to be saved, make sure the path exists
            if len(downsamp_path) and not os.path.exists(downsamp_path):
                mkdir_p(downsamp_path)
            data_file = self.create_downsample_file(self.data_file,
                                                    self.resample_rate,
                                                    downsamp_file)
            file_is_temp = not self.save_downsamp
            self.units_scale = convert_scale(1, 'uv', self.units)
            downsample_ratio = 1
        elif needs_downsamp:
            # The "else" case now is that the master electrode source (and ref and ground channels)
            # needs downsampling to PlainArraySources
            downsample_ratio = int(self.raw_sample_rate() / self.resample_rate)
        else:
            downsample_ratio = 1

        open_mode = 'r+' if file_is_temp else 'r'

        Fs = self.raw_sample_rate()
        if self.resample_rate:
            Fs = self.resample_rate

        # Find the full set of expected electrode channels within the "amplifier_data" array
        # electrode_channels = [n for n in range(n_amplifier_channels) if n not in ground_chans + ref_chans]
        if self.load_channels:
            # If load_channels is specified, need to modify the electrode channel list and find the channel map subset
            sub_channels = list()
            sub_indices = list()
            for i, n in enumerate(electrode_chans):
                if n in self.load_channels:
                    sub_channels.append(n)
                    sub_indices.append(i)
            electrode_chans = sub_channels
            channel_map = channel_map.subset(sub_indices)
            ground_chans = [n for n in ground_chans if n in self.load_channels]
            ref_chans = [n for n in ref_chans if n in self.load_channels]

        # Setting up sources is delegated to a method that can be overloaded. For example, open-ephys data are
        # stored file-per-channel; when loading a full-sampling-rate recording to memory, the original logic
        # would have required packing to HDF5 before loading (inefficient).
        datasource, ground_chans, ref_chans = self.map_raw_data(
            data_file, open_mode, electrode_chans, ground_chans, ref_chans,
            downsample_ratio)

        # Promote to a writeable and possibly RAM-loaded array here if either the final source should be loaded,
        # or if the mapped source is not writeable.
        needs_load = isinstance(datasource, MappedSource) and not self.mapped
        filtering = bool(self.bandpass) or bool(self.notches)
        needs_writeable = (self.ensure_writeable
                           or filtering) and not datasource.writeable
        if needs_load or needs_writeable:
            # Need to make writeable copies of these data sources. If the final source is to be loaded, then mirror
            # to memory here. Copy everything to memory if not mapped, otherwise copy only aligned arrays.
            if not self.mapped or not filtering:
                # Load data if not mapped. If mapped as writeable but not filtering, then copy to new file
                copy_mode = 'all'
            else:
                copy_mode = 'aligned'
            source_type = 'writeable mapped' if self.mapped else 'RAM'
            print('Creating {} sources with copy mode: {}'.format(
                source_type, copy_mode))
            datasource_w = datasource.mirror(mapped=self.mapped,
                                             writeable=True,
                                             copy=copy_mode)
            if ground_chans:
                ground_chans_w = ground_chans.mirror(mapped=self.mapped,
                                                     writeable=True,
                                                     copy=copy_mode)
            if ref_chans:
                ref_chans_w = ref_chans.mirror(mapped=self.mapped,
                                               writeable=True,
                                               copy=copy_mode)
            if not self.mapped or not filtering:
                # swap handles of these objects
                datasource = datasource_w
                datasource_w = None
                if ground_chans:
                    ground_chans = ground_chans_w
                    ground_chans_w = None
                if ref_chans:
                    ref_chans = ref_chans_w
                    ref_chans_w = None
        elif filtering:
            # in this case the datasource was already writeable/loaded
            datasource_w = ground_chans_w = ref_chans_w = None

        # For the filter blocks...
        # If mapped, then datasource and datasource_w will be identical (after filter_array call)
        # If loaded, then datasource_w is None and datasource is filtered in-place
        if self.bandpass:
            # TODO: should filter in two stages for stability
            # filter inplace if the "writeable" source is set to None
            filter_kwargs = dict(
                ftype='butterworth',
                inplace=datasource_w is None,
                design_kwargs=dict(lo=self.bandpass[0],
                                   hi=self.bandpass[1],
                                   Fs=Fs),
                filt_kwargs=dict(filtfilt=not self.causal_filtering))
            if self.mapped:
                # make "verbose" filtering with progress bar if we're filtering a mapped source
                filter_kwargs['filt_kwargs']['verbose'] = True
            print('Bandpass filtering')
            datasource = datasource.filter_array(out=datasource_w,
                                                 **filter_kwargs)
            if ground_chans:
                ground_chans = ground_chans.filter_array(out=ground_chans_w,
                                                         **filter_kwargs)
            if ref_chans:
                ref_chans = ref_chans.filter_array(out=ref_chans_w,
                                                   **filter_kwargs)
        if self.notches:
            print('Notch filtering')
            notch_kwargs = dict(
                inplace=datasource_w is None,
                lines=self.notches,
                filt_kwargs=dict(filtfilt=not self.causal_filtering))
            if self.mapped:
                notch_kwargs['filt_kwargs']['verbose'] = True
            datasource = datasource.notch_filter(Fs,
                                                 out=datasource_w,
                                                 **notch_kwargs)
            if ground_chans:
                ground_chans = ground_chans.notch_filter(Fs,
                                                         out=ground_chans_w,
                                                         **notch_kwargs)
            if ref_chans:
                ref_chans = ref_chans.notch_filter(Fs,
                                                   out=ref_chans_w,
                                                   **notch_kwargs)

        trigger_signal, pos_edge = self.find_trigger_signals(data_file)
        if not needs_file and downsample_ratio > 1:
            trigger_signal = trigger_signal[..., ::downsample_ratio]
            pos_edge = np.round(pos_edge.astype('d') /
                                downsample_ratio).astype('i')

        # Viventi lab convention: the stim signal would be on the next available ADC channel. Skip explicitly loading
        # it, because the "board_adc_data" array is conjoined with the main datasource.

        dataset = Bunch()
        dataset.data = datasource
        for arr in datasource.aligned_arrays:
            dataset[arr] = getattr(datasource, arr)
        dataset.chan_map = channel_map
        dataset.Fs = Fs
        dataset.pos_edge = pos_edge
        dataset.trig_chan = trigger_signal
        dataset.bandpass = self.bandpass
        dataset.notches = self.notches
        dataset.units = self.units
        dataset.transient_snipped = False
        dataset.ground_chans = ground_chans
        dataset.ref_chans = ref_chans
        dataset.loader = self
        return dataset
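A standalone sketch of the trigger bookkeeping near the end of `create_dataset` (the ratio, trace, and edge indices are invented): when the downsample was done in memory, the trigger channel is decimated and the raw-rate edge indices are rescaled to the new rate.

import numpy as np

downsample_ratio = 4
trigger_signal = np.arange(20)           # stand-in trigger trace at the raw rate
pos_edge = np.array([3, 11, 17])         # rising-edge indices at the raw rate
# Decimate the trigger and rescale edge indices to the downsampled rate.
trigger_signal = trigger_signal[..., ::downsample_ratio]
pos_edge = np.round(pos_edge.astype('d') / downsample_ratio).astype('i')
print(trigger_signal)   # [ 0  4  8 12 16]
print(pos_edge)         # [1 3 4]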
Example #4
def load_afe(exp_pth,
             test,
             electrode,
             n_data,
             range_code,
             cycle_rate,
             units='nA',
             bandpass=(),
             save=True,
             notches=(),
             snip_transient=True,
             **extra):

    h5 = tables.open_file(os.path.join(exp_pth, test + '.h5'))

    data = h5.root.data[:]
    Fs = h5.root.Fs[0, 0]
    h5.close()

    if data.shape[1] > n_data:
        trig_chans = data[:, n_data:]
        trig = np.any(trig_chans > 1, axis=1).astype('i')
        pos_edge = np.where(np.diff(trig) > 0)[0] + 1
    else:
        trig = None
        pos_edge = ()

    data_chans = data[:, :n_data].T.copy(order='C')

    # convert dynamic range to charge or current
    if 'v' not in units.lower():
        pico_coulombs = range_lookup[range_code]
        convert_dyn_range(data_chans, (-1.4, 1.4),
                          pico_coulombs,
                          out=data_chans)
        if 'a' in units.lower():
            # To convert to amps, need to divide coulombs by the
            # integration period. This is found approximately by
            # finding out how many cycles in a scan period were spent
            # integrating. A scan period is now hard coded to be 500
            # cycles. The cycling rate is given as an argument.
            # The integration period for channel i should be:
            # 500 - 2*(n_data - i)
            # That is, the 1st channel is clocked out over two cycles
            # immediately after the integration period. Meanwhile other
            # channels still acquire until they are clocked out.
            n_cycles = 500
            #i_cycles = n_cycles - 2*(n_data - np.arange(n_data))
            i_cycles = n_cycles - 2 * n_data
            i_period = i_cycles / cycle_rate
            data_chans /= i_period  #[:,None]
            convert_scale(data_chans, 'pa', units)
    elif units.lower() != 'v':
        # rescale voltage units (e.g. V to uV) on the data channels
        convert_scale(data_chans, 'v', units)

    # only use this one electrode (for now)
    chan_map, disconnected = epins.get_electrode_map('psv_61_afe')[:2]
    connected = np.setdiff1d(np.arange(n_data), disconnected)
    disconnected = disconnected[disconnected < n_data]

    chan_map = chan_map.subset(list(range(len(connected))))

    data = shm.shared_ndarray((len(connected), data_chans.shape[-1]))
    data[:, :] = data_chans[connected]
    ground_chans = data_chans[disconnected].copy()
    del data_chans

    if bandpass:
        # do a little extra to kill DC
        data -= data.mean(axis=1)[:, None]
        (b, a) = ft.butter_bp(lo=bandpass[0], hi=bandpass[1], Fs=Fs)
        filtfilt(data, b, a)
    if notches:
        for freq in notches:
            (b, a) = ft.notch(freq, Fs=Fs, ftype='cheby2')
            filtfilt(data, b, a)
    else:
        # no notch filtering: still remove the DC offset
        data -= data.mean(axis=1)[:, None]

    ## detrend_window = int(round(0.750*Fs))
    ## ft.bdetrend(data, bsize=detrend_window, type='linear', axis=-1)

    if snip_transient:
        snip_len = min(10000, pos_edge[0]) if len(pos_edge) else 10000
        data = data[..., snip_len:].copy()
        ground_chans = ground_chans[..., snip_len:].copy()
        if len(pos_edge):
            trig = trig[..., snip_len:]
            pos_edge -= snip_len

    dset = Bunch()

    dset.data = data
    dset.ground_chans = ground_chans
    dset.chan_map = chan_map
    dset.Fs = Fs
    dset.pos_edge = pos_edge
    dset.bandpass = bandpass
    dset.trig = trig
    dset.transient_snipped = snip_transient
    dset.units = units
    dset.notches = notches
    return dset
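A worked version of the charge-to-current arithmetic described in the comments above (n_data, cycle_rate, and the charge value are invented; the 500-cycle scan period and the pA-to-nA factor follow the code):

n_data = 16                       # number of data channels (hypothetical)
cycle_rate = 1e6                  # cycles per second (hypothetical)
n_cycles = 500                    # hard-coded scan period, in cycles
i_cycles = n_cycles - 2 * n_data  # cycles spent integrating: 500 - 32 = 468
i_period = i_cycles / cycle_rate  # integration period in seconds: 468e-6
q_pC = 2.0                        # picocoulombs accumulated over one scan
i_pA = q_pC / i_period            # picoamps, as in `data_chans /= i_period`
i_nA = i_pA * 1e-3                # pA -> nA, as in convert_scale(..., 'pa', 'na')
print(i_pA, i_nA)                 # 4273.504... 4.2735...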