class ButterworthFilter(PropertiedObject, BaseFilter):
    """Applies Butterworth filter to a time series.

    Keyword Arguments
    -----------------

    time_series
         TimeSeriesX object
    order
         Butterworth filter order
    freq_range : list-like
         Array [min_freq, max_freq] describing the filter range
    filt_type : str
         Filter type, e.g. 'stop' for a band-stop filter (the default)

    """

    _descriptors = [
        TypeValTuple('time_series', TimeSeriesX,
                     TimeSeriesX([0.0], dict(samplerate=1.), dims=['time'])),
        TypeValTuple('order', int, 4),
        TypeValTuple('freq_range', list, [58, 62]),
        TypeValTuple('filt_type', str, 'stop'),
    ]

    def __init__(self, **kwds):
        self.init_attrs(kwds)

    def filter(self):
        """
        Applies Butterworth filter to the input time series and returns a filtered TimeSeriesX object

        Returns
        -------
        filtered: TimeSeriesX
            The filtered time series

        """
        time_axis_index = get_axis_index(self.time_series, axis_name='time')
        filtered_array = buttfilt(self.time_series,
                                  self.freq_range,
                                  float(self.time_series['samplerate']),
                                  self.filt_type,
                                  self.order,
                                  axis=time_axis_index)

        coords_dict = {
            coord_name: DataArray(coord.copy())
            for coord_name, coord in list(self.time_series.coords.items())
        }
        coords_dict['samplerate'] = self.time_series['samplerate']
        dims = [dim_name for dim_name in self.time_series.dims]
        filtered_time_series = TimeSeriesX(filtered_array,
                                           dims=dims,
                                           coords=coords_dict)

        filtered_time_series.attrs = self.time_series.attrs.copy()
        return filtered_time_series
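
# Usage sketch: a minimal, hypothetical illustration of ButterworthFilter,
# assuming TimeSeriesX and ButterworthFilter are importable as defined above.
# A [58, 62] band-stop is the usual way to suppress 60 Hz line noise.
import numpy as np

def _butterworth_filter_demo():
    t = np.arange(500) / 500.0  # 1 s sampled at 500 Hz
    signal = np.sin(2 * np.pi * 10 * t) + 0.5 * np.sin(2 * np.pi * 60 * t)
    eeg = TimeSeriesX(signal, dims=['time'],
                      coords={'time': t, 'samplerate': 500.0})
    bfilter = ButterworthFilter(time_series=eeg, freq_range=[58., 62.],
                                order=4, filt_type='stop')
    return bfilter.filter()  # TimeSeriesX with the 60 Hz component attenuated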
Example #2
class ButterworthFilter(PropertiedObject, BaseFilter):

    '''
    Applies Butterworth filter to a time series
    '''
    _descriptors = [
        TypeValTuple('time_series', TimeSeriesX, TimeSeriesX([0.0], dims=['time'])),
        TypeValTuple('order', int, 4),
        TypeValTuple('freq_range', list, [58, 62]),
        TypeValTuple('filt_type', str, 'stop'),
    ]

    def __init__(self, **kwds):
        '''
        Constructor

        :param kwds: allowed values are:
        -------------------------------------
        :param time_series - TimeSeriesX object
        :param order - Butterworth filter order
        :param freq_range - array of frequencies [min_freq, max_freq] to filter out
        :return: None
        '''
        self.init_attrs(kwds)

    def filter(self):
        '''
        Applies Butterworth filter to the input time series and returns a filtered TimeSeriesX object
        :return: TimeSeriesX object
        '''

        from ptsa.filt import buttfilt

        time_axis_index = get_axis_index(self.time_series, axis_name='time')
        filtered_array = buttfilt(self.time_series,
                                  self.freq_range, float(self.time_series['samplerate']), self.filt_type,
                                  self.order, axis=time_axis_index)

        coords_dict = {coord_name: DataArray(coord.copy()) for coord_name, coord in self.time_series.coords.items()}
        coords_dict['samplerate'] = self.time_series['samplerate']
        dims = [dim_name for dim_name in self.time_series.dims]
        filtered_time_series = TimeSeriesX(
            filtered_array,
            dims=dims,
            coords=coords_dict
        )

        filtered_time_series = TimeSeriesX(filtered_time_series)

        return filtered_time_series
Example #3
class EEGReader(PropertiedObject, BaseReader):
    """
    Reader that knows how to read binary eeg files. It can read chunks of the eeg signal based on events input
    or can read entire session if session_dataroot is non empty.

    Keyword Arguments
    -----------------
    channels : np.ndarray
      numpy array of channel labels
    start_time : float
       read start offset in seconds w.r.t. the eegoffset specified in the events recarray
    end_time : float
       read end offset in seconds w.r.t. the eegoffset specified in the events recarray
    buffer_time : float
       extra buffer in seconds (subtracted from start read and added to end read)
    events : np.recarray
       numpy recarray representing Events
    session_dataroot : str
       path to session dataroot. When set the reader will read the entire session

    """
    _descriptors = [
        TypeValTuple('channels', np.ndarray, np.array([], dtype='|S3')),
        TypeValTuple('start_time', float, 0.0),
        TypeValTuple('end_time', float, 0.0),
        TypeValTuple('buffer_time', float, 0.0),
        TypeValTuple('events', object, object),
        TypeValTuple('session_dataroot', six.string_types, ''),
    ]

    READER_FILETYPE_DICT = defaultdict(lambda: BaseRawReader)
    READER_FILETYPE_DICT.update({'.h5': H5RawReader})

    def __init__(self, **kwds):
        """

        """
        self.init_attrs(kwds)
        self.removed_corrupt_events = False
        self.event_ok_mask_sorted = None

        assert self.start_time <= self.end_time, \
            'start_time (%s) must be less than or equal to end_time (%s)' % (self.start_time, self.end_time)

        self.read_fcn = self.read_events_data
        if self.session_dataroot:
            self.read_fcn = self.read_session_data
        self.channel_name = 'channels'

    def compute_read_offsets(self, dataroot):
        """
        Reads Parameter file and exracts sampling rate that is used to convert from start_time, end_time, buffer_time
        (expressed in seconds)
        to start_offset, end_offset, buffer_offset expressed as integers indicating number of time series data points (not bytes!)

        :param dataroot: core name of the eeg datafile
        :return: tuple of 3 {int} - start_offset, end_offset, buffer_offset
        """
        p_reader = ParamsReader(dataroot=dataroot)
        params = p_reader.read()
        samplerate = params['samplerate']

        start_offset = int(np.round(self.start_time * samplerate))
        end_offset = int(np.round(self.end_time * samplerate))
        buffer_offset = int(np.round(self.buffer_time * samplerate))

        return start_offset, end_offset, buffer_offset
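
    # Worked example (hypothetical numbers): with samplerate = 500 Hz,
    # start_time = 0.0, end_time = 1.6 and buffer_time = 1.0,
    #   start_offset  = round(0.0 * 500) = 0
    #   end_offset    = round(1.6 * 500) = 800
    #   buffer_offset = round(1.0 * 500) = 500
    # so each event read spans 800 - 0 + 2 * 500 = 1800 samples
    # (see read_size in __create_base_raw_readers below).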

    def __create_base_raw_readers(self):
        """
        Creates a BaseRawReader for each (unique) dataroot present in the events recarray
        :return: list of BaseRawReaders and list of dataroots
        :raises: IncompatibleDataError if the readers are not all the same class
        """
        evs = self.events
        dataroots = np.unique(evs.eegfile)
        raw_readers = []
        original_dataroots = []

        for dataroot in dataroots:
            events_with_matched_dataroot = evs[evs.eegfile == dataroot]

            start_offset, end_offset, buffer_offset = self.compute_read_offsets(
                dataroot=dataroot)

            read_size = end_offset - start_offset + 2 * buffer_offset

            start_offsets = events_with_matched_dataroot.eegoffset + start_offset - buffer_offset

            brr = self.READER_FILETYPE_DICT[os.path.splitext(dataroot)[-1]](
                dataroot=dataroot,
                channels=self.channels,
                start_offsets=start_offsets,
                read_size=read_size)
            raw_readers.append(brr)

            original_dataroots.append(dataroot)

        return raw_readers, original_dataroots

    def read_session_data(self):
        """
        Reads entire session worth of data

        :return: TimeSeriesX object (channels x events x time) with data for entire session the events dimension has length 1
        """
        brr = self.READER_FILETYPE_DICT[self.session_dataroot](
            dataroot=self.session_dataroot, channels=self.channels)
        session_array, read_ok_mask = brr.read()
        self.channel_name = brr.channel_name

        offsets_axis = session_array['offsets']
        number_of_time_points = offsets_axis.shape[0]
        samplerate = float(session_array['samplerate'])
        physical_time_array = np.arange(number_of_time_points) * (1.0 /
                                                                  samplerate)


        session_time_series = TimeSeriesX(
            session_array.values,
            dims=[self.channel_name, 'start_offsets', 'time'],
            coords={
                self.channel_name: session_array[self.channel_name],
                'start_offsets': session_array['start_offsets'],
                'time': physical_time_array,
                'offsets': ('time', session_array['offsets']),
                'samplerate': session_array['samplerate']
            })
        session_time_series.attrs = session_array.attrs.copy()
        session_time_series.attrs['dataroot'] = self.session_dataroot

        return session_time_series

    def removed_bad_data(self):
        return self.removed_corrupt_events

    def get_event_ok_mask(self):
        return self.event_ok_mask_sorted

    def read_events_data(self):
        """
        Reads eeg data for individual event

        :return: TimeSeriesX  object (channels x events x time) with data for individual events
        """
        self.event_ok_mask_sorted = None  # reset self.event_ok_mask_sorted

        evs = self.events

        raw_readers, original_dataroots = self.__create_base_raw_readers()

        # used for restoring original order of the events
        ordered_indices = np.arange(len(evs))
        event_indices_list = []
        events = []

        ts_array_list = []

        event_ok_mask_list = []

        for s, (raw_reader,
                dataroot) in enumerate(zip(raw_readers, original_dataroots)):

            ts_array, read_ok_mask = raw_reader.read()

            event_ok_mask_list.append(np.all(read_ok_mask, axis=0))

            ind = np.atleast_1d(evs.eegfile == dataroot)
            event_indices_list.append(ordered_indices[ind])
            events.append(evs[ind])

            ts_array_list.append(ts_array)

        if not all([
                r.channel_name == raw_readers[0].channel_name
                for r in raw_readers
        ]):
            raise IncompatibleDataError(
                'cannot read monopolar and bipolar data together')

        self.channel_name = raw_readers[0].channel_name

        event_indices_array = np.hstack(event_indices_list)

        event_indices_restore_sort_order_array = event_indices_array.argsort()

        start_extend_time = time.time()
        # new code
        eventdata = xr.concat(ts_array_list, dim='start_offsets')
        samplerate = float(eventdata['samplerate'])
        tdim = np.arange(eventdata.shape[-1]) * (1.0 / samplerate) + (
            self.start_time - self.buffer_time)
        cdim = eventdata[self.channel_name]
        edim = np.concatenate(events).view(np.recarray).copy()

        attrs = eventdata.attrs.copy()
        # constructing TimeSeries Object
        eventdata = TimeSeriesX(eventdata.data,
                                dims=[self.channel_name, 'events', 'time'],
                                coords={
                                    self.channel_name: cdim,
                                    'events': edim,
                                    'time': tdim,
                                    'samplerate': samplerate
                                })

        eventdata.attrs = attrs

        # restoring original order of the events
        eventdata = eventdata[:, event_indices_restore_sort_order_array, :]

        event_ok_mask = np.hstack(event_ok_mask_list)
        event_ok_mask_sorted = event_ok_mask[
            event_indices_restore_sort_order_array]
        # removing bad events
        if np.any(~event_ok_mask_sorted):
            self.removed_corrupt_events = True
            self.event_ok_mask_sorted = event_ok_mask_sorted

        eventdata = eventdata[:, event_ok_mask_sorted, :]

        return eventdata


    def read(self):
        """
        Calls read_events_data or read_session_data depending on user selection

        :return: TimeSeriesX object
        """
        return self.read_fcn()
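
# Usage sketch: a hypothetical event-based read, assuming `events` is a
# recarray with 'eegfile' and 'eegoffset' fields (e.g. produced by the
# CMLEventReader below) and channel labels matching the recordings on disk:
#
#     reader = EEGReader(events=events,
#                        channels=np.array(['002', '003'], dtype='|S3'),
#                        start_time=0.0, end_time=1.6, buffer_time=1.0)
#     eeg = reader.read()  # TimeSeriesX: channels x events x time
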
class CMLEventReader(BaseEventReader):
    """Event reader that returns original PTSA Events object with attached
    rawbinwrappers -- objects that know how to read eeg binary data

    Keyword arguments
    -----------------
    filename : str
        path to event file
    eliminate_events_with_no_eeg : bool
        flag to automatically remove events with no eegfile (default True)
    eliminate_nans : bool
        flag to automatically replace nans in the event structs with -999
        (default True)
    eeg_fname_search_pattern : str
        pattern in the eeg filename to search for, in order to replace it with
        eeg_fname_replace_pattern
    eeg_fname_replace_pattern : str
        replacement pattern for the eeg filename. It replaces all occurrences
        of "eeg_fname_search_pattern"
    normalize_eeg_path : bool
        flag that determines whether 'data1', 'data2', etc. in the eeg path get
        converted to 'data'. When True, all 'data1', 'data2', etc. are converted
        to 'data'; the flag is False by default

    """

    _descriptors = [
        TypeValTuple('eeg_fname_search_pattern', six.string_types, ''),
        TypeValTuple('eeg_fname_replace_pattern', six.string_types, ''),
        TypeValTuple('normalize_eeg_path', bool, False),
    ]

    def __init__(self, **kwds):
        BaseEventReader.__init__(self, **kwds)

        self.alter_eeg_path_flag = (self.eeg_fname_search_pattern != '' and
                                    self.eeg_fname_replace_pattern != '')

    def modify_eeg_path(self, events):
        """Replaces search pattern (self.eeg_fname_search_patter') with replace
        pattern (self.eeg_fname_replace_pattern) in every eegfile entry in the
        events recarray

        Parameters
        ----------
        events : np.recarray
            recarray representing events. One of the fields of this array should
            be eegfile

        """
        for ev in events:
            ev.eegfile = ev.eegfile.replace(self.eeg_fname_search_pattern,
                                            self.eeg_fname_replace_pattern)
        return events

    def check_reader_settings_for_json_read(self):
        pass
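
# Usage sketch: reading an event file and rewriting eeg paths; the file name
# and patterns below are hypothetical:
#
#     reader = CMLEventReader(filename='/data/events/R1111M_events.mat',
#                             eeg_fname_search_pattern='eeg.reref',
#                             eeg_fname_replace_pattern='eeg.noreref')
#     events = reader.read()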
Example #5
class MorletWaveletFilterCppLegacy(PropertiedObject, BaseFilter):
    _descriptors = [
        TypeValTuple('freqs', np.ndarray, np.array([], dtype=np.float64)),
        TypeValTuple('width', int, 5),
        TypeValTuple('output', str, 'power'),
        TypeValTuple('frequency_dim_pos', int, 0),
        # NOTE in this implementation the default position of frequency is -2
        TypeValTuple('verbose', bool, True),
    ]

    def __init__(self, time_series, **kwds):

        self.window = None
        self.time_series = time_series
        self.init_attrs(kwds)

    def all_but_time_iterator(self, array):
        from itertools import product
        sizes_except_time = np.asarray(array.shape)[:-1]
        ranges = map(lambda size: range(size), sizes_except_time)
        for cart_prod_idx_tuple in product(*ranges):
            yield cart_prod_idx_tuple, array[cart_prod_idx_tuple]

    def allocate_output_arrays(self, time_axis_size):
        array_type = np.float32
        shape = self.time_series.shape[:-1] + (
            self.freqs.shape[0],
            time_axis_size,
        )

        if self.output == 'power':
            return np.empty(shape=shape, dtype=array_type), None
        elif self.output == 'phase':
            return None, np.empty(shape=shape, dtype=array_type)
        else:
            return np.empty(shape=shape,
                            dtype=array_type), np.empty(shape=shape,
                                                        dtype=array_type)

    def store(self, idx_tuple, target_array, source_array):
        if source_array is None or target_array is None: return

        num_wavelets = self.freqs.shape[0]
        time_axis_size = self.time_series.shape[-1]
        for w in range(num_wavelets):
            out_idx_tuple = idx_tuple + (w, )
            target_array[out_idx_tuple] = source_array[w *
                                                       time_axis_size:(w + 1) *
                                                       time_axis_size]

    def get_data_iterator(self):
        return self.all_but_time_iterator(self.time_series)

    def construct_output_array(self, array, dims, coords):
        out_array = xr.DataArray(array, dims=dims, coords=coords)
        out_array['samplerate'] = self.time_series['samplerate']
        return out_array

    def build_output_arrays(self, wavelet_pow_array, wavelet_phase_array,
                            time_axis):
        wavelet_pow_array_xray = None
        wavelet_phase_array_xray = None

        if isinstance(self.time_series, xr.DataArray):

            dims = list(self.time_series.dims[:-1] + (
                'frequency',
                'time',
            ))

            transposed_dims = []

            # NOTE all computations up to this point assume that the frequency position is -2, whereas
            # the default setting for this filter sets the frequency axis index to 0. To avoid unnecessary
            # transpositions we need to adjust the position of the frequency axis in the internal computations

            # getting frequency dim position as positive integer
            self.frequency_dim_pos = (len(dims) +
                                      self.frequency_dim_pos) % len(dims)
            orig_frequency_idx = dims.index('frequency')

            if self.frequency_dim_pos != orig_frequency_idx:
                transposed_dims = dims[:orig_frequency_idx] + dims[
                    orig_frequency_idx + 1:]
                transposed_dims.insert(self.frequency_dim_pos, 'frequency')

            coords = {
                dim_name: self.time_series.coords[dim_name]
                for dim_name in self.time_series.dims[:-1]
            }
            coords['frequency'] = self.freqs
            coords['time'] = time_axis

            if 'offsets' in list(self.time_series.coords.keys()):
                coords['offsets'] = ('time', self.time_series['offsets'])

            if wavelet_pow_array is not None:
                wavelet_pow_array_xray = self.construct_output_array(
                    wavelet_pow_array, dims=dims, coords=coords)
            if wavelet_phase_array is not None:
                wavelet_phase_array_xray = self.construct_output_array(
                    wavelet_phase_array, dims=dims, coords=coords)

            if wavelet_pow_array_xray is not None:
                wavelet_pow_array_xray = TimeSeriesX(wavelet_pow_array_xray)
                if len(transposed_dims):
                    wavelet_pow_array_xray = wavelet_pow_array_xray.transpose(
                        *transposed_dims)

                wavelet_pow_array_xray.attrs = self.time_series.attrs.copy()

            if wavelet_phase_array_xray is not None:
                wavelet_phase_array_xray = TimeSeriesX(
                    wavelet_phase_array_xray)
                if len(transposed_dims):
                    wavelet_phase_array_xray = wavelet_phase_array_xray.transpose(
                        *transposed_dims)

                wavelet_phase_array_xray.attrs = self.time_series.attrs.copy()

            return wavelet_pow_array_xray, wavelet_phase_array_xray

    def filter(self):

        data_iterator = self.get_data_iterator()

        time_axis = self.time_series['time']

        time_axis_size = time_axis.shape[0]
        samplerate = float(self.time_series['samplerate'])

        wavelet_pow_array, wavelet_phase_array = self.allocate_output_arrays(
            time_axis_size=time_axis_size)

        num_wavelets = self.freqs.shape[0]

        powers = np.array([], dtype=np.float64)
        phases = np.array([], dtype=np.float64)
        wavelets_complex_reshaped = np.array([[]], dtype=np.complex128)

        if self.output == 'power':
            powers = np.empty(shape=(time_axis_size * num_wavelets, ),
                              dtype=np.float64)
        if self.output == 'phase':
            phases = np.empty(shape=(time_axis_size * num_wavelets, ),
                              dtype=np.float64)
        if self.output == 'both':
            powers = np.empty(shape=(time_axis_size * num_wavelets, ),
                              dtype=np.float64)
            phases = np.empty(shape=(time_axis_size * num_wavelets, ),
                              dtype=np.float64)

        morlet_transform = MorletWaveletTransform()
        morlet_transform.init_flex(self.width, self.freqs, samplerate,
                                   time_axis_size)

        wavelet_start = time.time()

        if self.output in ('phase', 'both'):
            for idx_tuple, signal in data_iterator:
                morlet_transform.multiphasevec(signal, powers, phases)
                self.store(idx_tuple, wavelet_pow_array, powers)
                self.store(idx_tuple, wavelet_phase_array, phases)
        elif self.output == 'power':
            for idx_tuple, signal in data_iterator:
                morlet_transform.multiphasevec(signal, powers)
                self.store(idx_tuple, wavelet_pow_array, powers)

        if self.verbose:
            print('total time wavelet loop: ', time.time() - wavelet_start)

        return self.build_output_arrays(wavelet_pow_array, wavelet_phase_array,
                                        time_axis)
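
# Usage sketch: computing wavelet power at log-spaced frequencies, assuming
# `eeg` is a TimeSeriesX whose last axis is 'time' (hypothetical input):
#
#     wf = MorletWaveletFilterCppLegacy(eeg,
#                                       freqs=np.logspace(np.log10(3),
#                                                         np.log10(180), 8),
#                                       width=5, output='power')
#     pow_wavelet, phase_wavelet = wf.filter()  # phase_wavelet is None here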
Example #6
class MonopolarToBipolarMapper(PropertiedObject, BaseFilter):
    """
    Object that takes as an input time series for monopolar electrodes and an array of bipolar pairs and outputs
    Time series where 'channels' axis is replaced by 'bipolar_pairs' axis and the time series data is a difference
    between time series corresponding to different electrodes as specified by bipolar pairs
    """
    _descriptors = [
        TypeValTuple('time_series', TimeSeriesX,
                     TimeSeriesX([0.0], dims=['time'])),
        TypeValTuple(
            'bipolar_pairs', np.recarray,
            np.recarray((0, ), dtype=[('ch0', '|S3'), ('ch1', '|S3')])),
    ]

    def __init__(self, **kwds):
        """
        Constructor:

        :param kwds:allowed values are:
        -------------------------------------
        :param time_series  -  TimeSeriesX object with eeg session data and 'channels as one of the axes'
        :param bipolar_pairs {np.recarray} - an array of bipolar electrode pairs

        :return: None
        """

        self.init_attrs(kwds)

    def filter(self):
        """
        Turns time series for monopolar electrodes into time series where where 'channels' axis is replaced by
        'bipolar_pairs' axis and the time series data is a difference
        between time series corresponding to different electrodes as specified by bipolar pairs

        :return: TimeSeriesX object
        """

        channel_axis = self.time_series['channels']

        ch0 = self.bipolar_pairs['ch0']
        ch1 = self.bipolar_pairs['ch1']

        sel0 = channel_axis.loc[ch0]
        sel1 = channel_axis.loc[ch1]

        ts0 = self.time_series.loc[dict(channels=sel0)]
        ts1 = self.time_series.loc[dict(channels=sel1)]

        dims_bp = list(self.time_series.dims)
        channels_idx = dims_bp.index('channels')
        dims_bp[channels_idx] = 'bipolar_pairs'


        coords_bp = {
            coord_name: coord
            for coord_name, coord in self.time_series.coords.items()
        }
        del coords_bp['channels']
        coords_bp['bipolar_pairs'] = self.bipolar_pairs

        ts = TimeSeriesX(data=ts0.values - ts1.values,
                         dims=dims_bp,
                         coords=coords_bp)
        ts['samplerate'] = self.time_series['samplerate']

        ts.attrs = self.time_series.attrs.copy()
        return ts
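
# Usage sketch: re-referencing a synthetic two-channel series to a single
# bipolar pair; the recarray dtype matches the descriptor default above.
import numpy as np

def _bipolar_mapper_demo():
    eeg = TimeSeriesX(np.arange(10, dtype=float).reshape(2, 5),
                      dims=['channels', 'time'],
                      coords={'channels': np.array(['001', '002'], dtype='|S3'),
                              'time': np.arange(5),
                              'samplerate': 1.0})
    pairs = np.recarray((1, ), dtype=[('ch0', '|S3'), ('ch1', '|S3')])
    pairs[0] = (b'001', b'002')
    mapper = MonopolarToBipolarMapper(time_series=eeg, bipolar_pairs=pairs)
    return mapper.filter()  # data is channel '001' minus channel '002'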
Example #7
class TimeSeriesEEGReader(PropertiedObject):
    _descriptors = [
        TypeValTuple('samplerate', float, -1.0),
        TypeValTuple('keep_buffer', bool, True),
        TypeValTuple('buffer_time', float, 0.0),
        TypeValTuple('start_time', float, 0.0),
        TypeValTuple('end_time', float, 0.0),
        TypeValTuple('events', np.recarray,
                     np.recarray((0, ), dtype=[('x', int)])),
    ]

    def __init__(self, **kwds):

        self.init_attrs(kwds)

    def __create_bin_readers(self):
        evs = self.events
        eegfiles = np.unique(evs.eegfile)
        raw_bin_wrappers = []
        original_eeg_files = []

        for eegfile in eegfiles:
            events_with_matched_eegfile = evs[evs.eegfile == eegfile]
            ev_with_matched_eegfile = events_with_matched_eegfile[0]
            try:

                eeg_file_path = ev_with_matched_eegfile.eegfile

                raw_bin_wrappers.append(RawBinWrapperXray(eeg_file_path))

                original_eeg_files.append(eegfile)

                inds = np.where(evs.eegfile == eegfile)[0]

                if self.samplerate < 0.0:
                    data_params = raw_bin_wrappers[-1]._get_params(
                        eeg_file_path)

                    self.samplerate = data_params['samplerate']

            except TypeError:
                print('skipping event with eegfile=', evs.eegfile)

        raw_bin_wrappers = np.array(raw_bin_wrappers,
                                    dtype=np.dtype(RawBinWrapperXray))

        return raw_bin_wrappers, original_eeg_files

    def get_number_of_samples_for_interval(self, time_interval):
        return int(np.ceil(time_interval * self.samplerate))

    def __compute_time_series_length(self):

        # translate back to dur and offset
        dur = self.end_time - self.start_time
        offset = self.start_time
        buf = self.buffer_time

        # set event durations from rate
        # get the sample interval (the code calls it "samplesize";
        # arguably it should be called sample_interval)
        samplesize = 1. / self.samplerate

        # get the number of buffer samples
        buf_samp = int(np.ceil(buf / samplesize))

        # calculate the offset samples that contains the desired offset
        offset_samp = int(
            np.ceil((np.abs(offset) - samplesize * .5) / samplesize) *
            np.sign(offset))

        # finally get the duration necessary to cover the desired span
        dur_samp = (int(np.ceil(
            (dur + offset - samplesize * .5) / samplesize)) - offset_samp + 1)

        # add in the buffer
        dur_samp += 2 * buf_samp
        offset_samp -= buf_samp

        return dur_samp
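
    # Worked example (hypothetical numbers): with samplerate = 100 Hz
    # (samplesize = 0.01 s), start_time = -0.5, end_time = 1.0 and
    # buffer_time = 0.5:
    #   buf_samp    = ceil(0.5 / 0.01)                        = 50
    #   offset_samp = ceil((0.5 - 0.005) / 0.01) * sign(-0.5) = -50
    #   dur_samp    = ceil((1.5 - 0.5 - 0.005) / 0.01) + 50 + 1 = 151
    #   dur_samp   += 2 * buf_samp  ->  251 samples in total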

    def read(self, channels):
        evs = self.events

        raw_bin_wrappers, original_eeg_files = self.__create_bin_readers()

        # we need to create rawbinwrappers first to figure out sample rate before calling __compute_time_series_length()
        time_series_length = self.__compute_time_series_length()

        time_series_data = np.empty(
            (len(channels), len(evs), time_series_length),
            dtype=np.float64) * np.nan

        ordered_indices = np.arange(len(evs))
        event_indices_list = []
        events = []
        newdat_list = []
        eventdata = None

        for s, (src,
                eegfile) in enumerate(zip(raw_bin_wrappers,
                                          original_eeg_files)):
            ind = np.atleast_1d(evs.eegfile == eegfile)

            event_indices_list.append(ordered_indices[ind])

            if len(ind) == 1:
                event_offsets = evs['eegoffset']
                events.append(evs)
            else:
                event_offsets = evs[ind]['eegoffset']
                events.append(evs[ind])

            # get the timeseries for those events
            newdat = src.get_event_data_xray(channels,
                                             event_offsets,
                                             self.start_time,
                                             self.end_time,
                                             self.buffer_time,
                                             resampled_rate=None,
                                             filt_freq=None,
                                             filt_type=None,
                                             filt_order=None,
                                             keep_buffer=self.keep_buffer,
                                             loop_axis=None,
                                             num_mp_procs=0,
                                             eoffset='eegoffset',
                                             eoffset_in_time=False)
            newdat_list.append(newdat)

        event_indices_array = np.hstack(event_indices_list)
        event_indices_restore_sort_order_array = event_indices_array.argsort()

        start_extend_time = time.time()

        #new code
        eventdata = xr.concat(newdat_list, dim='events')
        end_extend_time = time.time()

        # concatenate (must eventually check that dims match)
        tdim = eventdata['time']
        cdim = eventdata['channels']
        srate = eventdata.attrs['samplerate']
        events = np.concatenate(events).view(Events)

        eventdata_xray = xr.DataArray(eventdata.values,
                                      coords=[cdim, events, tdim],
                                      dims=['channels', 'events', 'time'])
        eventdata_xray.attrs['samplerate'] = eventdata.attrs['samplerate']
        eventdata_xray = eventdata_xray[:,
                                        event_indices_restore_sort_order_array, :]  # restore original event order

        if not self.keep_buffer:
            # trimming buffer data samples
            number_of_buffer_samples = self.get_number_of_samples_for_interval(
                self.buffer_time)
            if number_of_buffer_samples > 0:
                eventdata_xray = eventdata_xray[:, :, number_of_buffer_samples:
                                                -number_of_buffer_samples]

        return TimeSeriesX(eventdata_xray)

    def read_all(self, channels, start_offset, end_offset, buffer):
        evs = self.events

        raw_bin_wrappers, original_eeg_files = self.__create_bin_readers()

        # we need to create rawbinwrappers first to figure out sample rate before calling __compute_time_series_length()
        time_series_length = self.__compute_time_series_length()

        time_series_data = np.empty(
            (len(channels), len(evs), time_series_length),
            dtype=np.float64) * np.nan

        events = []

        newdat_list = []

        for s, (src,
                eegfile) in enumerate(zip(raw_bin_wrappers,
                                          original_eeg_files)):
            ind = np.atleast_1d(evs.eegfile == eegfile)

            if len(ind) == 1:
                events.append(evs[0])
            else:
                events.append(evs[ind])

            # get the timeseries for those events
            newdat = src.get_event_data_xray_simple(channels=channels,
                                                    events=events,
                                                    start_offset=start_offset,
                                                    end_offset=end_offset,
                                                    buffer=buffer)

            newdat_list.append(newdat)

        start_extend_time = time.time()
        #new code
        eventdata = xr.concat(newdat_list, dim='events')
        end_extend_time = time.time()

        # concatenate (must eventually check that dims match)
        tdim = eventdata['time']
        cdim = eventdata['channels']
        srate = eventdata.attrs['samplerate']

        eventdata_xray = eventdata

        if not self.keep_buffer:
            # trimming buffer data samples
            number_of_buffer_samples = self.get_number_of_samples_for_interval(
                self.buffer_time)
            if number_of_buffer_samples > 0:
                eventdata_xray = eventdata_xray[:, :, number_of_buffer_samples:
                                                -number_of_buffer_samples]

        return eventdata_xray
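
# Usage sketch (hypothetical): the legacy reader infers samplerate from the
# first RawBinWrapperXray it creates, so samplerate may stay at its -1.0
# default until read() is called:
#
#     reader = TimeSeriesEEGReader(events=events, start_time=0.0,
#                                  end_time=1.6, buffer_time=1.0,
#                                  keep_buffer=False)
#     eeg = reader.read(channels=np.array(['002', '003'], dtype='|S3'))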
Example #8
class PTSAEventReader(BaseEventReader, BaseReader):
    '''
    Event reader that returns the original PTSA Events object with attached rawbinwrappers.
    Rawbinwrappers are objects that know how to read eeg binary data
    '''

    _descriptors = [
        TypeValTuple('attach_rawbinwrapper', bool, True),
        TypeValTuple('use_groupped_rawbinwrapper', bool, True),
    ]

    def __init__(self, **kwds):
        '''
        Constructor:

        :param kwds: allowed values are:
        -------------------------------------
        :param filename {str} - path to event file
        :param eliminate_events_with_no_eeg {bool} - flag to automatically remove events with no eegfile (default True)
        :param use_reref_eeg {bool} - flag that changes eegfiles to point to reref eegs. Default is False, and the eegs
        read are the non-reref ones
        :param attach_rawbinwrapper {bool} - flag signaling whether to attach rawbinwrappers to the Events obj or not.
        Default is True
        :param use_groupped_rawbinwrapper {bool} - flag signaling whether to use groupped rawbinwrappers
        (i.e. shared between many events with the same eegfile) or not. Default is True. When True, data reads are much faster
        :return: None
        '''
        BaseEventReader.__init__(self, **kwds)

    def read(self):
        '''
        Reads a Matlab event file, converts it to np.recarray and attaches rawbinwrappers (if the appropriate flags indicate so)
        :return: Events object. Depending on flag settings the rawbinwrappers may be attached as well
        '''

        # calling base class read fcn
        evs = BaseEventReader.read(self)

        # in case evs is simply recarray
        if not isinstance(evs, Events):
            evs = Events(evs)

        if self.attach_rawbinwrapper:
            evs = evs.add_fields(esrc=np.dtype(RawBinWrapper))

            if self.use_groupped_rawbinwrapper:  # this should be default choice - much faster execution
                self.attach_rawbinwrapper_groupped(evs)
            else:  # used for debugging purposes
                self.attach_rawbinwrapper_individual(evs)

        return evs

    def attach_rawbinwrapper_groupped(self, evs):
        '''
        Attaches rawbinwrappers to individual records. A single rawbinwrapper is shared between events that have the
        same eegfile
        :param evs: Events object
        :return: Events object with attached rawbinwrappers
        '''

        eegfiles = np.unique(evs.eegfile)

        for eegfile in eegfiles:

            raw_bin_wrapper = RawBinWrapper(eegfile)
            inds = np.where(evs.eegfile == eegfile)[0]
            for i in inds:
                evs[i]['esrc'] = raw_bin_wrapper

    def attach_rawbinwrapper_individual(self, evs):
        '''
        Attaches rawbinwrappers to individual records. Uses a separate rawbinwrapper for each record
        :param evs: Events object
        :return: Events object with attached rawbinwrappers
        '''

        for ev in evs:
            try:
                if self.attach_rawbinwrapper:
                    ev.esrc = RawBinWrapper(ev.eegfile)
            except TypeError:
                print('skipping event with eegfile=', ev.eegfile)
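
# Usage sketch: reading a Matlab event file and attaching shared
# rawbinwrappers; the file name below is hypothetical:
#
#     reader = PTSAEventReader(filename='/data/events/R1111M_events.mat',
#                              attach_rawbinwrapper=True,
#                              use_groupped_rawbinwrapper=True)
#     events = reader.read()  # Events recarray with an 'esrc' field attached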
Example #9
class BaseRawReader(PropertiedObject, BaseReader):
    """
    Object that knows how to read binary eeg files
    """
    _descriptors = [
        TypeValTuple('dataroot', six.string_types, ''),
        TypeValTuple('channels', np.ndarray, np.array([], dtype='|S3')),
        TypeValTuple('start_offsets', np.ndarray, np.array([0], dtype=int)),
        TypeValTuple('read_size', int, -1),
    ]

    def __init__(self, **kwds):
        """
        Constructor
        :param kwds:allowed values are:
        -------------------------------------
        :param dataroot {str} -  core name of the eegfile file (i.e. full path except extension e.g. '.002').
        Normally this is eegfile field from events record

        :param channels {list} - list of channels (list of strings) that should be read
        :param start_offsets {ndarray} -  array of ints with read offsets
        :param read_size {int} - size of the read chunk. If -1 the entire file is read
        --------------------------------------
        :return:None
        """

        self.init_attrs(kwds)

        FileFormat = namedtuple('FileFormat', ['data_size', 'format_string'])
        self.file_format_dict = {
            'single': FileFormat(data_size=4, format_string='f'),
            'float32': FileFormat(data_size=4, format_string='f'),
            'short': FileFormat(data_size=2, format_string='h'),
            'int16': FileFormat(data_size=2, format_string='h'),
            'int32': FileFormat(data_size=4, format_string='i'),
            'double': FileFormat(data_size=8, format_string='d'),
            'float64': FileFormat(data_size=8, format_string='d'),
        }

        self.file_format = self.file_format_dict['int16']
        if isinstance(self.dataroot, six.binary_type):
            self.dataroot = self.dataroot.decode()

        p_reader = ParamsReader(dataroot=self.dataroot)
        self.params_dict = p_reader.read()

        try:
            format_name = self.params_dict['format']
            try:
                self.file_format = self.file_format_dict[format_name]
            except KeyError:
                raise RuntimeError('Unsupported format: %s. Allowed format names are: %s' % (
                    format_name, list(self.file_format_dict.keys())))
        except KeyError:
            warnings.warn('Could not find data format definition in the params file. Will read the file assuming' \
                          ' data format is int16', RuntimeWarning)

    def get_file_size(self):
        """

        :return: {int} size of the files whose core name (dataroot) matches self.dataroot. Assumes ALL files with this
        dataroot are of the same length and uses first channel to determin the common file length
        """
        if isinstance(self.channels[0], six.binary_type):
            ch = self.channels[0].decode()
        else:
            ch = self.channels[0]
        eegfname = self.dataroot + '.' + ch
        return os.path.getsize(eegfname)

    def read(self):
        """

        :return: DataArray objects populated with data read from eeg files. The size of the output is
        number of channels x number of start offsets x number of time series points
        The corresponding DataArray axes are: 'channels', 'start_offsets', 'offsets'

        """

        if self.read_size < 0:
            self.read_size = int(self.get_file_size() / self.file_format.data_size)

        # allocate space for data
        eventdata = np.empty((len(self.channels), len(self.start_offsets), self.read_size),
                             dtype=np.float64) * np.nan

        read_ok_mask = np.ones(shape=(len(self.channels), len(self.start_offsets)), dtype=bool)

        # loop over channels
        for c, channel in enumerate(self.channels):
            try:
                eegfname = self.dataroot + '.' + channel
            except TypeError:
                eegfname = self.dataroot + '.' + channel.decode()

            with open(eegfname, 'rb') as efile:
                # loop over start offsets
                for e, start_offset in enumerate(self.start_offsets):
                    # rejecting negative offset
                    if start_offset < 0:
                        read_ok_mask[c, e] = False
                        print(('Cannot read from negative offset %d in file %s' % (start_offset, eegfname)))
                        continue

                    # seek to the position in the file
                    efile.seek(self.file_format.data_size * start_offset, 0)

                    # read the data
                    data = efile.read(int(self.file_format.data_size * self.read_size))

                    # convert from string to array based on the format
                    # hard-codes little endian
                    fmt = '<' + str(int(len(data) / self.file_format.data_size)) + self.file_format.format_string
                    data = np.array(struct.unpack(fmt, data))

                    # make sure we got some data
                    if len(data) < self.read_size:
                        read_ok_mask[c, e] = False

                        print((
                            'Cannot read full chunk of data for offset ' + str(start_offset) +
                            '. End of read interval is outside the bounds of file ' + str(eegfname)))
                    else:
                        # append it to the eventdata
                        eventdata[c, e, :] = data

        # multiply by the gain
        eventdata *= self.params_dict['gain']

        eventdata = DataArray(eventdata,
                              dims=['channels', 'start_offsets', 'offsets'],
                              coords={
                                  'channels': self.channels,
                                  'start_offsets': self.start_offsets.copy(),
                                  'offsets': np.arange(self.read_size),
                                  'samplerate': self.params_dict['samplerate']

                              }
                              )

        from copy import deepcopy
        eventdata.attrs = deepcopy(self.params_dict)

        return eventdata, read_ok_mask
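
# Decode sketch: a self-contained illustration of the little-endian
# struct.unpack round-trip used in read() above (int16, data_size = 2).
import struct
import numpy as np

def _decode_demo():
    raw = struct.pack('<4h', 10, -20, 30, -40)  # 4 raw int16 samples
    n_samples = len(raw) // 2                   # data_size = 2 bytes
    fmt = '<' + str(n_samples) + 'h'            # e.g. '<4h'
    return np.array(struct.unpack(fmt, raw))    # array([ 10, -20,  30, -40])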
Example #10
class DataChopper(PropertiedObject, BaseFilter):
    """
    DataChopper converts a continuous time series of an entire session into chunks based on the events specification.
    In other words, you may read an entire eeg session first and then use DataChopper
    to divide it into chunks corresponding to the events of your choice
    """
    _descriptors = [
        TypeValTuple('start_time', float, 0.0),
        TypeValTuple('end_time', float, 0.0),
        TypeValTuple('buffer_time', float, 0.0),
        TypeValTuple('events', np.recarray, np.recarray((1,), dtype=[('x', int)])),
        TypeValTuple('start_offsets', np.ndarray, np.array([], dtype=int)),
        TypeValTuple('session_data', TimeSeriesX, TimeSeriesX([0.0], dims=['time'])),
    ]

    def __init__(self, **kwds):
        """
        Constructor:

        :param kwds:allowed values are:
        -------------------------------------
        :param start_time {float} -  read start offset in seconds w.r.t to the eegeffset specified in the events recarray
        :param end_time {float} -  read end offset in seconds w.r.t to the eegeffset specified in the events recarray
        :param end_time {float} -  extra buffer in seconds (subtracted from start read and added to end read)
        :param events {np.recarray} - numpy recarray representing events
        :param startoffsets {np.ndarray} - numpy array with offsets at which chopping should take place
        :param session_datar {str} -  TimeSeriesX object with eeg session data

        :return: None
        """

        self.init_attrs(kwds)

    def get_event_chunk_size_and_start_point_shift(self, eegoffset, samplerate, offset_time_array):
        """
        Computes number of time points for each event and read offset w.r.t. event's eegoffset
        :param ev: record representing single event
        :param samplerate: samplerate fo the time series
        :param offset_time_array: "offsets" axis of the DataArray returned by EEGReader. This is the axis that represents
        time axis but instead of beind dimensioned to seconds it simply represents position of a given data point in a series
        The time axis is constructed by dividint offsets axis by the samplerate
        :return: event's read chunk size {int}, read offset w.r.t. to event's eegoffset {}
        """
        # figuring out the read chunk size and shift w.r.t. eegoffset. We need this fcn in case we pass resampled session data
        original_samplerate = float(offset_time_array[-1] - offset_time_array[0]) \
            / offset_time_array.shape[0] * samplerate

        start_point = eegoffset - int(np.ceil((self.buffer_time - self.start_time) * original_samplerate))
        end_point = eegoffset + int(
            np.ceil((self.end_time + self.buffer_time) * original_samplerate))

        selector_array = np.where((offset_time_array >= start_point) & (offset_time_array < end_point))[0]
        start_point_shift = selector_array[0] - np.where((offset_time_array >= eegoffset))[0][0]

        return len(selector_array), start_point_shift


    def filter(self):
        """
        Chops session into chunks orresponding to events
        :return: timeSeriesX object with chopped session
        """
        chop_on_start_offsets_flag = bool(len(self.start_offsets))

        if chop_on_start_offsets_flag:

            start_offsets = self.start_offsets
            chopping_axis_name = 'start_offsets'
            chopping_axis_data = start_offsets
        else:

            evs = self.events[self.events.eegfile == self.session_data.attrs['dataroot']]
            start_offsets = evs.eegoffset
            chopping_axis_name = 'events'
            chopping_axis_data = evs


        samplerate = float(self.session_data['samplerate'])
        offset_time_array = self.session_data['offsets']

        event_chunk_size, start_point_shift = self.get_event_chunk_size_and_start_point_shift(
            eegoffset=start_offsets[0],
            samplerate=samplerate,
            offset_time_array=offset_time_array)

        event_time_axis = np.arange(event_chunk_size) * (1.0 / samplerate) + (self.start_time - self.buffer_time)

        data_list = []

        for i, eegoffset in enumerate(start_offsets):

            start_chop_pos = np.where(offset_time_array >= eegoffset)[0][0]
            start_chop_pos += start_point_shift
            selector_array = np.arange(start=start_chop_pos, stop=start_chop_pos + event_chunk_size)

            chopped_data_array = self.session_data.isel(time=selector_array)

            chopped_data_array['time'] = event_time_axis
            chopped_data_array['start_offsets'] = [i]

            data_list.append(chopped_data_array)

        ev_concat_data = xr.concat(data_list, dim='start_offsets')

        ev_concat_data = ev_concat_data.rename({'start_offsets': chopping_axis_name})
        ev_concat_data[chopping_axis_name] = chopping_axis_data

        ev_concat_data['samplerate'] = samplerate
        ev_concat_data.attrs['start_time'] = self.start_time
        ev_concat_data.attrs['end_time'] = self.end_time
        ev_concat_data.attrs['buffer_time'] = self.buffer_time
        return TimeSeriesX(ev_concat_data)
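
# Usage sketch: chopping a whole-session read into per-event epochs, assuming
# `session_eeg` came from EEGReader(session_dataroot=...).read() and therefore
# carries an 'offsets' coordinate and a 'dataroot' attribute:
#
#     chopper = DataChopper(events=events, session_data=session_eeg,
#                           start_time=0.0, end_time=1.6, buffer_time=1.0)
#     chopped = chopper.filter()  # TimeSeriesX with an 'events' axis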
Example #11
class TalReader(PropertiedObject, BaseReader):
    """
    Reader that reads tal structs Matlab file and converts it to numpy recarray
    """
    _descriptors = [
        TypeValTuple('filename', six.string_types, ''),
        TypeValTuple('struct_name', six.string_types, 'bpTalStruct'),
        TypeValTuple('struct_type', six.string_types, 'bi')
    ]

    def __init__(self, **kwds):
        """
        Keyword arguments
        -----------------

        :param filename {str} -  path to tal file or pairs.json file
        :param struct_name {str} -  name of the matlab struct to load
        :param struct_type {str} - either 'mono', indicating a monopolar struct, or 'bi', indicating a bipolar struct
        :return: None

        """

        self.init_attrs(kwds)
        self.bipolar_channels = None

        self.tal_struct_array = None
        self._json = os.path.splitext(self.filename)[-1] == '.json'
        if self.struct_type not in ['bi', 'mono']:
            raise AttributeError(
                'Value %s not a valid struct_type. Please choose either "mono" or "bi"'
                % self.struct_type)
        if self.struct_type == 'mono':
            self.struct_name = 'talStruct'

    def get_bipolar_pairs(self):
        """

        :return: numpy recarray where each record has two fields 'ch0' and 'ch1' storing  channel labels.
        """
        if self.bipolar_channels is None:
            if self.tal_struct_array is None:
                self.read()
            self.initialize_bipolar_pairs()

        return self.bipolar_channels

    def get_monopolar_channels(self):
        """

        :return: numpy array of monopolar channel labels
        """
        if self.struct_type == 'bi':
            bipolar_array = self.get_bipolar_pairs()
            monopolar_set = set(
                list(bipolar_array['ch0']) + list(bipolar_array['ch1']))
            return np.array(sorted(list(monopolar_set)))
        else:
            if self.tal_struct_array is None:
                self.read()
            return np.array(
                ['{:03d}'.format(c) for c in self.tal_struct_array['channel']])

    def initialize_bipolar_pairs(self):
        # initialize bipolar pairs
        self.bipolar_channels = np.recarray(shape=(len(self.tal_struct_array)),
                                            dtype=[('ch0', '|S3'),
                                                   ('ch1', '|S3')])

        channel_record_array = self.tal_struct_array['channel']
        for i, channel_array in enumerate(channel_record_array):
            self.bipolar_channels[i] = tuple(
                map(lambda x: str(x).zfill(3), channel_array))

    def from_dict(self, pairs):
        if self.struct_type == 'bi':
            pairs = pd.DataFrame.from_dict(
                list(pairs.values())[0]['pairs'],
                orient='index').sort_values(by=['channel_1', 'channel_2'])
            pairs.index.name = 'tagName'
            pairs['channel'] = [[ch1, ch2] for ch1, ch2 in zip(
                pairs.channel_1.values, pairs.channel_2.values)]
            pairs['eType'] = pairs.type_1
            return pairs.to_records()
        elif self.struct_type == 'mono':
            contacts = pd.DataFrame.from_dict(
                list(pairs.values())[0]['contacts'],
                orient='index').sort_values(by='channel')
            contacts.index.name = 'tagName'
            return contacts.to_records()

    def read(self):
        """

        :return: np.recarray representing tal struct array (originally defined in Matlab file)
        """
        if not self._json:
            from ptsa.data.MatlabIO import read_single_matlab_matrix_as_numpy_structured_array

            struct_names = ['bpTalStruct', 'subjTalEvents']
            if self.struct_name not in struct_names:
                self.tal_struct_array = read_single_matlab_matrix_as_numpy_structured_array(
                    self.filename, self.struct_name, verbose=False)
                if self.tal_struct_array is not None:
                    return self.tal_struct_array
                else:
                    raise AttributeError(
                        'Could not read tal struct data for the specified struct_name='
                        + self.struct_name)

            else:

                for sn in struct_names:
                    self.tal_struct_array = read_single_matlab_matrix_as_numpy_structured_array(
                        self.filename, sn, verbose=False)
                    if self.tal_struct_array is not None:
                        return self.tal_struct_array
        else:
            with open(self.filename) as fp:
                pairs = json.load(fp)
            self.tal_struct_array = self.from_dict(pairs)
            return self.tal_struct_array

        raise AttributeError(
            'Could not read tal struct data. Try specifying struct_name argument :'
            '\nTalReader(filename=e_path, struct_name=<name_of_struc_to_read>)'
        )
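
# A minimal usage sketch for this reader (the tal file path and subject are
# hypothetical; the default struct_name 'bpTalStruct' is assumed to exist):
tal_reader = TalReader(filename='/data/eeg/R1111M/tal/R1111M_talLocs_database_bipol.mat')
tal_structs = tal_reader.read()
bipolar_pairs = tal_reader.get_bipolar_pairs()            # recarray with 'ch0'/'ch1' fields
monopolar_channels = tal_reader.get_monopolar_channels()  # sorted zero-padded channel labels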
Example #12
class EventDataChopper(PropertiedObject):
    _descriptors = [
        TypeValTuple('time_shift', float, 0.0),
        TypeValTuple('event_duration', float, 0.0),
        TypeValTuple('buffer', float, 1.0),
        TypeValTuple('events', np.recarray,
                     np.recarray((1, ), dtype=[('x', int)])),
        TypeValTuple('data_dict', dict, {}),
    ]

    def __init__(self, **kwds):

        self.init_attrs(kwds)

    def get_event_chunk_size_and_start_point_shift(self, ev, samplerate,
                                                   offset_time_array):
        # figuring out the read chunk size and shift w.r.t. eegoffset

        original_samplerate = float(
            (offset_time_array[-1] -
             offset_time_array[0])) / offset_time_array.shape[0] * samplerate

        start_point = ev.eegoffset - int(
            np.ceil((self.buffer + self.time_shift) * original_samplerate))
        end_point = ev.eegoffset + int(
            np.ceil((self.event_duration + self.buffer + self.time_shift) *
                    original_samplerate))

        selector_array = np.where((offset_time_array >= start_point)
                                  & (offset_time_array < end_point))[0]
        start_point_shift = selector_array[0] - np.where(
            (offset_time_array >= ev.eegoffset))[0][0]

        return len(selector_array), start_point_shift

    def filter(self):

        event_data_dict = OrderedDict()

        for eegfile_name, data in list(self.data_dict.items()):

            evs = self.events[self.events.eegfile == eegfile_name]

            samplerate = data.attrs['samplerate']

            # used in constructing time_axis
            offset_time_array = data['time'].values['eegoffset']

            event_chunk_size, start_point_shift = self.get_event_chunk_size_and_start_point_shift(
                ev=evs[0],
                samplerate=samplerate,
                offset_time_array=offset_time_array)

            event_time_axis = np.linspace(
                -self.buffer + self.time_shift,
                self.event_duration + self.buffer + self.time_shift,
                event_chunk_size)

            data_list = []

            shape = None

            for i, ev in enumerate(evs):
                # print ev.eegoffset
                start_chop_pos = np.where(
                    offset_time_array >= ev.eegoffset)[0][0]
                start_chop_pos += start_point_shift
                selector_array = np.arange(start=start_chop_pos,
                                           stop=start_chop_pos +
                                           event_chunk_size)

                # ev_array = eeg_session_data[:,:,selector_array] # ORIG CODE

                chopped_data_array = data.isel(time=selector_array)

                chopped_data_array['time'] = event_time_axis
                chopped_data_array['events'] = [i]

                data_list.append(chopped_data_array)

                # print i

            ev_concat_data = xr.concat(data_list, dim='events')

            # replacing simple events axis (consecutive integers) with recarray of events
            ev_concat_data['events'] = evs

            ev_concat_data.attrs['samplerate'] = samplerate
            ev_concat_data.attrs['time_shift'] = self.time_shift
            ev_concat_data.attrs['event_duration'] = self.event_duration
            ev_concat_data.attrs['buffer'] = self.buffer

            event_data_dict[eegfile_name] = TimeSeriesX(ev_concat_data)


        return event_data_dict
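
# A minimal usage sketch (hypothetical inputs: `base_events` is an events
# recarray and `session_dict` maps eegfile names to session TimeSeriesX
# objects, e.g. the output of TimeSeriesSessionEEGReader.read() below):
chopper = EventDataChopper(events=base_events, event_duration=1.6,
                           buffer=1.0, data_dict=session_dict)
event_segments_dict = chopper.filter()  # OrderedDict: eegfile name -> TimeSeriesX
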
class TimeSeriesSessionEEGReader(PropertiedObject):
    _descriptors = [
        TypeValTuple('samplerate', float, -1.0),
        TypeValTuple('channels', np.ndarray, np.array([],dtype='|S3')),
        TypeValTuple('offset', int, 0),
        TypeValTuple('event_data_only', bool, False),
        TypeValTuple('default_buffer', float, 10.0),
        TypeValTuple('events', np.recarray, np.recarray((1,), dtype=[('x', int)])),

    ]

    def __init__(self, **kwds):

        for option_name, val in kwds.items():

            try:
                attr = getattr(self, option_name)
                setattr(self, option_name, val)
            except AttributeError:
                print('Option: ' + option_name + ' is not allowed')

        self.eegfile_names = self._extract_session_eegfile_names()

        self._extract_samplerate(self.eegfile_names[0])

        self.bin_readers_dict = self._create_bin_readers(self.eegfile_names)

    def get_session_eegfile_names(self):
        return self.eegfile_names

    def get_number_of_sessions(self):
        return len(self.eegfile_names)

    def _extract_session_eegfile_names(self):

        # sorting file names in the order in which they appear in the file
        evs = self.events
        eegfile_names = np.unique(evs.eegfile)
        eeg_file_names_sorter = np.zeros(len(eegfile_names), dtype=np.int)

        for i, eegfile_name in enumerate(eegfile_names):
            eeg_file_names_sorter[i] = np.where(evs.eegfile == eegfile_name)[0][0]

        eeg_file_names_sorter = np.argsort(eeg_file_names_sorter)

        eegfile_names = eegfile_names[eeg_file_names_sorter]

        # print eegfile_names

        return eegfile_names

    def _extract_samplerate(self, eegfile_name):
        rbw_xray = RawBinWrapperXray(eegfile_name)
        data_params = rbw_xray._get_params(eegfile_name)
        self.samplerate = data_params['samplerate']

    def _create_bin_readers(self, eegfile_names):

        bin_readers_dict = OrderedDict()

        for eegfile_name in eegfile_names:
            try:
                bin_readers_dict[eegfile_name] = RawBinWrapperXray(eegfile_name)
            except TypeError:
                warning_str = 'Could not create reader for %s' % eegfile_name
                print warning_str
                raise TypeError(warning_str)

        return bin_readers_dict

    def determine_read_offset_range(self, eegfile_name):
        # determine events that have matching eegfile_name
        evs = self.events[self.events.eegfile == eegfile_name]

        start_offset = evs[0].eegoffset - int(np.ceil(self.default_buffer*self.samplerate))
        if start_offset < 0:
            start_offset = 0

        end_offset = evs[-1].eegoffset + int(np.ceil(self.default_buffer*self.samplerate))

        return start_offset, end_offset

    def read_session(self, eegfile_name):
        samplesize = 1.0 / self.samplerate

        bin_reader = self.bin_readers_dict[eegfile_name]

        print('reading', eegfile_name)

        start_offset = self.offset
        end_offset = -1
        if self.event_data_only:
            # reading continuous data containing events and a small buffer
            start_offset, end_offset = self.determine_read_offset_range(eegfile_name)

        eegdata = bin_reader._load_all_data(channels=self.channels, start_offset=start_offset, end_offset=end_offset)


        # constructing time axis as record array [(session_time_in_sec, offset)]

        number_of_time_points = eegdata.shape[2]
        start_time = start_offset * samplesize
        end_time = start_time + number_of_time_points * samplesize

        time_range = np.linspace(start_time, end_time, number_of_time_points)
        eegoffset = np.arange(start_offset, start_offset+ number_of_time_points)

        time_axis = np.rec.fromarrays([time_range, eegoffset], names='time,eegoffset')

        # constructing xray DataArray with session eeg data - note we are adding an event dimension to simplify
        # chopping of the data into events - single events will be concatenated along the events axis
        eegdata_xray = xray.DataArray(eegdata, coords=[self.channels, np.arange(1), time_axis],
                                      dims=['channels', 'events', 'time'])
        eegdata_xray.attrs['samplerate'] = self.samplerate

        print('last_time_stamp=', eegdata_xray['time'][-1])

        return TimeSeriesX(eegdata_xray)

    def read(self, session_list=[]):

        session_eegdata_dict = OrderedDict()
        samplesize = 1.0 / self.samplerate

        eegfile_names = list(self.bin_readers_dict.keys()) if len(session_list) == 0 else session_list
        for eegfile_name in eegfile_names:
            eegdata_xray = self.read_session(eegfile_name)
            session_eegdata_dict[eegfile_name] = eegdata_xray

        return session_eegdata_dict
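
# A minimal usage sketch (hypothetical `base_events` recarray; channel labels
# follow the zero-padded '|S3' convention used throughout this module):
session_reader = TimeSeriesSessionEEGReader(
    events=base_events,
    channels=np.array(['002', '003', '004'], dtype='|S3'))
session_dict = session_reader.read()  # OrderedDict: eegfile name -> TimeSeriesX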
Example #14
class EEGReader(PropertiedObject, BaseReader):
    '''
    Reader that knows how to read binary eeg files. It can read chunks of the eeg signal based on events input,
    or it can read an entire session if session_dataroot is non-empty.
    '''
    _descriptors = [
        TypeValTuple('channels', np.ndarray, np.array([], dtype='|S3')),
        TypeValTuple('start_time', float, 0.0),
        TypeValTuple('end_time', float, 0.0),
        TypeValTuple('buffer_time', float, 0.0),
        TypeValTuple('events', np.recarray,
                     np.recarray((0, ), dtype=[('x', int)])),
        TypeValTuple('session_dataroot', str, ''),
    ]

    def __init__(self, **kwds):
        '''
        Constructor
        :param kwds:allowed values are:
        -------------------------------------
        :param channels {np.ndarray} -  numpy array of channel labels
        :param start_time {float} -  read start offset in seconds w.r.t. the eegoffset specified in the events recarray
        :param end_time {float} -  read end offset in seconds w.r.t. the eegoffset specified in the events recarray
        :param buffer_time {float} -  extra buffer in seconds (subtracted from the start read and added to the end read)
        :param events {np.recarray} - numpy recarray representing Events
        :param session_dataroot {str} -  path to session dataroot. When set the reader will read the entire session

        :return:None
        '''
        self.init_attrs(kwds)
        self.removed_corrupt_events = False
        self.event_ok_mask_sorted = None

        assert self.start_time <= self.end_time, \
            'start_time (%s) must be less than or equal to end_time (%s)' % (self.start_time, self.end_time)

        self.read_fcn = self.read_events_data
        if self.session_dataroot:
            self.read_fcn = self.read_session_data

    def compute_read_offsets(self, dataroot):
        '''
        Reads the params file and extracts the sampling rate, which is used to convert start_time, end_time and buffer_time
        (expressed in seconds)
        to start_offset, end_offset and buffer_offset, expressed as integers indicating numbers of time series data points (not bytes!)

        :param dataroot: core name of the eeg datafile
        :return: tuple of 3 {int} - start_offset, end_offset, buffer_offset
        '''
        p_reader = ParamsReader(dataroot=dataroot)
        params = p_reader.read()
        samplerate = params['samplerate']

        start_offset = int(np.round(self.start_time * samplerate))
        end_offset = int(np.round(self.end_time * samplerate))
        buffer_offset = int(np.round(self.buffer_time * samplerate))
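        # e.g. with samplerate=500.0, start_time=0.0, end_time=1.6 and
        # buffer_time=1.0, this yields (0, 800, 500): 1.6 s of signal plus a
        # 1.0 s buffer on each side, measured in samples rather than seconds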

        return start_offset, end_offset, buffer_offset

    def __create_base_raw_readers(self):
        '''
        Creates a BaseRawReader for each (unique) dataroot present in the events recarray
        :return: list of BaseRawReaders and list of dataroots
        '''
        evs = self.events
        dataroots = np.unique(evs.eegfile)
        raw_readers = []
        original_dataroots = []

        for dataroot in dataroots:
            events_with_matched_dataroot = evs[evs.eegfile == dataroot]

            start_offset, end_offset, buffer_offset = self.compute_read_offsets(
                dataroot=dataroot)

            read_size = end_offset - start_offset + 2 * buffer_offset

            start_offsets = events_with_matched_dataroot.eegoffset + start_offset - buffer_offset

            brr = BaseRawReader(dataroot=dataroot,
                                channels=self.channels,
                                start_offsets=start_offsets,
                                read_size=read_size)
            raw_readers.append(brr)

            original_dataroots.append(dataroot)

        return raw_readers, original_dataroots

    def read_session_data(self):
        '''
        Reads entire session worth of data
        :return: TimeSeriesX object (channels x events x time) with data for the entire session; the events dimension has length 1
        '''
        brr = BaseRawReader(dataroot=self.session_dataroot,
                            channels=self.channels)
        session_array, read_ok_mask = brr.read()

        offsets_axis = session_array['offsets']
        number_of_time_points = offsets_axis.shape[0]
        samplerate = float(session_array['samplerate'])
        physical_time_array = np.arange(number_of_time_points) * (1.0 /
                                                                  samplerate)

        # session_array = session_array.rename({'start_offsets': 'events'})

        session_time_series = TimeSeriesX(
            session_array.values,
            dims=['channels', 'start_offsets', 'time'],
            coords={
                'channels': session_array['channels'],
                'start_offsets': session_array['start_offsets'],
                'time': physical_time_array,
                'offsets': ('time', session_array['offsets']),
                'samplerate': session_array['samplerate']
                # 'dataroot':self.session_dataroot
            })
        session_time_series.attrs = session_array.attrs.copy()
        session_time_series.attrs['dataroot'] = self.session_dataroot

        return session_time_series

    def removed_bad_data(self):
        return self.removed_corrupt_events

    def get_event_ok_mask(self):
        return self.event_ok_mask_sorted

    def read_events_data(self):
        '''
        Reads eeg data for individual event
        :return: TimeSeriesX  object (channels x events x time) with data for individual events
        '''
        self.event_ok_mask_sorted = None  # reset self.event_ok_mask_sorted

        evs = self.events

        raw_readers, original_dataroots = self.__create_base_raw_readers()

        # used for restoring original order of the events
        ordered_indices = np.arange(len(evs))
        event_indices_list = []
        events = []

        ts_array_list = []

        event_ok_mask_list = []

        for s, (raw_reader,
                dataroot) in enumerate(zip(raw_readers, original_dataroots)):

            ts_array, read_ok_mask = raw_reader.read()

            event_ok_mask_list.append(np.all(read_ok_mask, axis=0))

            ind = np.atleast_1d(evs.eegfile == dataroot)
            event_indices_list.append(ordered_indices[ind])
            events.append(evs[ind])

            ts_array_list.append(ts_array)

        event_indices_array = np.hstack(event_indices_list)

        event_indices_restore_sort_order_array = event_indices_array.argsort()

        start_extend_time = time.time()
        # new code
        eventdata = xr.concat(ts_array_list, dim='start_offsets')
        # tdim = np.linspace(self.start_time-self.buffer_time,self.end_time+self.buffer_time,num=eventdata['offsets'].shape[0])
        # samplerate=eventdata.attrs['samplerate'].data
        samplerate = float(eventdata['samplerate'])
        tdim = np.arange(eventdata.shape[-1]) * (1.0 / samplerate) + (
            self.start_time - self.buffer_time)
        cdim = eventdata['channels']
        edim = np.concatenate(events).view(np.recarray).copy()

        attrs = eventdata.attrs.copy()
        # constructing TimeSeries Object
        # eventdata = TimeSeriesX(eventdata.data,dims=['channels','events','time'],coords=[cdim,edim,tdim])
        eventdata = TimeSeriesX(eventdata.data,
                                dims=['channels', 'events', 'time'],
                                coords={
                                    'channels': cdim,
                                    'events': edim,
                                    'time': tdim,
                                    'samplerate': samplerate
                                })

        eventdata.attrs = attrs

        # restoring original order of the events
        eventdata = eventdata[:, event_indices_restore_sort_order_array, :]

        event_ok_mask = np.hstack(event_ok_mask_list)
        event_ok_mask_sorted = event_ok_mask[
            event_indices_restore_sort_order_array]
        #removing bad events
        if np.any(~event_ok_mask_sorted):
            self.removed_corrupt_events = True
            self.event_ok_mask_sorted = event_ok_mask_sorted

        eventdata = eventdata[:, event_ok_mask_sorted, :]

        return eventdata


    def read(self):
        '''
        Calls read_events_data or read_session_data depending on user selection
        :return: TimeSeriesX object
        '''
        return self.read_fcn()
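
# A minimal usage sketch (hypothetical `base_events` recarray and channels;
# reads a 1.6 s window at each event's eegoffset, padded by a 1.0 s buffer
# on each side):
eeg_reader = EEGReader(events=base_events,
                       channels=np.array(['002', '003'], dtype='|S3'),
                       start_time=0.0, end_time=1.6, buffer_time=1.0)
base_eegs = eeg_reader.read()  # TimeSeriesX: channels x events x time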
Example #15
class MorletWaveletFilter(PropertiedObject, BaseFilter):
    _descriptors = [
        TypeValTuple('freqs', np.ndarray, np.array([], dtype=np.float)),
        TypeValTuple('width', int, 5),
        TypeValTuple('output', str, ''),
        TypeValTuple('frequency_dim_pos', int, 0),
        # NOTE in this implementation the default position of frequency is -2
        TypeValTuple('verbose', bool, True),
    ]

    def __init__(self, time_series, **kwds):

        self.window = None
        self.time_series = time_series
        self.init_attrs(kwds)

        self.compute_power_and_phase_fcn = None

        if self.output == 'power':
            self.compute_power_and_phase_fcn = self.compute_power
        elif self.output == 'phase':
            self.compute_power_and_phase_fcn = self.compute_phase
        else:
            self.compute_power_and_phase_fcn = self.compute_power_and_phase

    def all_but_time_iterator(self, array):
        from itertools import product
        sizes_except_time = np.asarray(array.shape)[:-1]
        ranges = [range(size) for size in sizes_except_time]
        for cart_prod_idx_tuple in product(*ranges):
            yield cart_prod_idx_tuple, array[cart_prod_idx_tuple]

    def resample_time_axis(self):
        from ptsa.data.filters.ResampleFilter import ResampleFilter

        rs_time_axis = None  # resampled time axis
        if self.resamplerate > 0:

            rs_time_filter = ResampleFilter(resamplerate=self.resamplerate)
            rs_time_filter.set_input(self.time_series[0, 0, :])
            time_series_resampled = rs_time_filter.filter()
            rs_time_axis = time_series_resampled['time']
        else:
            rs_time_axis = self.time_series['time']

        return rs_time_axis, self.time_series['time']

    def allocate_output_arrays(self, time_axis_size):
        array_type = np.float32
        shape = self.time_series.shape[:-1] + (
            self.freqs.shape[0],
            time_axis_size,
        )

        if self.output == 'power':
            return np.empty(shape=shape, dtype=array_type), None
        elif self.output == 'phase':
            return None, np.empty(shape=shape, dtype=array_type)
        else:
            return np.empty(shape=shape,
                            dtype=array_type), np.empty(shape=shape,
                                                        dtype=array_type)

    def compute_power(self, wavelet_coef_array):
        # return wavelet_coef_array.real ** 2 + wavelet_coef_array.imag ** 2, None
        return np.abs(wavelet_coef_array)**2, None
        # # wavelet_coef_array.real ** 2 + wavelet_coef_array.imag ** 2, None

    def compute_phase(self, wavelet_coef_array):
        return None, np.angle(wavelet_coef_array)

    def compute_power_and_phase(self, wavelet_coef_array):
        return wavelet_coef_array.real**2 + wavelet_coef_array.imag**2, np.angle(
            wavelet_coef_array)

    def store(self, idx_tuple, target_array, source_array):
        if source_array is not None:
            target_array[idx_tuple] = source_array

    def get_data_iterator(self):
        return self.all_but_time_iterator(self.time_series)

    def construct_output_array(self, array, dims, coords):
        out_array = xray.DataArray(array, dims=dims, coords=coords)
        # out_array.attrs['samplerate'] = self.time_series.attrs['samplerate']
        out_array['samplerate'] = self.time_series['samplerate']
        return out_array

    def build_output_arrays(self, wavelet_pow_array, wavelet_phase_array,
                            time_axis):
        wavelet_pow_array_xray = None
        wavelet_phase_array_xray = None

        if isinstance(self.time_series, xray.DataArray):

            dims = list(self.time_series.dims[:-1] + (
                'frequency',
                'time',
            ))

            transposed_dims = []

            # NOTE all computations up to this point assume that the frequency position is -2, whereas
            # the default setting for this filter sets the frequency axis index to 0. To avoid unnecessary transpositions
            # we need to adjust the position of the frequency axis in the internal computations

            # getting frequency dim position as positive integer
            self.frequency_dim_pos = (len(dims) +
                                      self.frequency_dim_pos) % len(dims)
            orig_frequency_idx = dims.index('frequency')

            if self.frequency_dim_pos != orig_frequency_idx:
                transposed_dims = dims[:orig_frequency_idx] + dims[
                    orig_frequency_idx + 1:]
                transposed_dims.insert(self.frequency_dim_pos, 'frequency')

            coords = {
                dim_name: self.time_series.coords[dim_name]
                for dim_name in self.time_series.dims[:-1]
            }
            coords['frequency'] = self.freqs
            coords['time'] = time_axis

            if 'offsets' in self.time_series.coords.keys():
                coords['offsets'] = ('time', self.time_series['offsets'])

            if wavelet_pow_array is not None:
                wavelet_pow_array_xray = self.construct_output_array(
                    wavelet_pow_array, dims=dims, coords=coords)
            if wavelet_phase_array is not None:
                wavelet_phase_array_xray = self.construct_output_array(
                    wavelet_phase_array, dims=dims, coords=coords)

            if wavelet_pow_array_xray is not None:
                wavelet_pow_array_xray = TimeSeriesX(wavelet_pow_array_xray)
                if len(transposed_dims):
                    wavelet_pow_array_xray = wavelet_pow_array_xray.transpose(
                        *transposed_dims)

                wavelet_pow_array_xray.attrs = self.time_series.attrs.copy()

            if wavelet_phase_array_xray is not None:
                wavelet_phase_array_xray = TimeSeriesX(
                    wavelet_phase_array_xray)
                if len(transposed_dims):
                    wavelet_phase_array_xray = wavelet_phase_array_xray.transpose(
                        *transposed_dims)

                wavelet_phase_array_xray.attrs = self.time_series.attrs.copy()

            return wavelet_pow_array_xray, wavelet_phase_array_xray

    def compute_wavelet_ffts(self):

        # samplerate = self.time_series.attrs['samplerate']
        samplerate = float(self.time_series['samplerate'])

        freqs = np.atleast_1d(self.freqs)

        wavelets = morlet_multi(freqs=freqs,
                                widths=self.width,
                                samplerates=samplerate)
        # ADD WARNING HERE FROM PHASE_MULTI

        num_wavelets = len(wavelets)

        # computing length of the longest wavelet
        s_w = max(map(lambda wavelet: wavelet.shape[0], wavelets))

        time_series_length = self.time_series['time'].shape[0]

        if s_w > self.time_series['time'].shape[0]:
            raise ValueError(
                'Time series length (l_ts=%s) is shorter than maximum wavelet length (l_w=%s). '
                'Please use longer time series or increase lowest wavelet frequency '
                % (time_series_length, s_w))

        # length of the time axis of the time series
        s_d = self.time_series['time'].shape[0]

        # determine the size based on the next power of 2
        convolution_size = s_w + s_d - 1
        convolution_size_pow2 = np.power(2, next_pow2(convolution_size))
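        # e.g. with a longest wavelet of 701 samples and a 2300-sample signal,
        # convolution_size is 3000 and the FFT length rounds up to 4096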

        # preallocating arrays
        # wavelet_fft_array = np.empty(shape=(num_wavelets, convolution_size_pow2), dtype=np.complex64)
        wavelet_fft_array = np.empty(shape=(num_wavelets,
                                            convolution_size_pow2),
                                     dtype=np.complex)
        convolution_size_array = np.empty(shape=(num_wavelets), dtype=np.int)

        # computing wavelet ffts
        for i, wavelet in enumerate(wavelets):
            wavelet_fft_array[i] = fft(wavelet, convolution_size_pow2)
            convolution_size_array[i] = wavelet.shape[0] + s_d - 1

        return wavelet_fft_array, convolution_size_array, convolution_size_pow2

    def filter(self):

        data_iterator = self.get_data_iterator()

        time_axis = self.time_series['time']

        time_axis_size = time_axis.shape[0]

        wavelet_pow_array, wavelet_phase_array = self.allocate_output_arrays(
            time_axis_size=time_axis_size)

        # preallocating array
        wavelet_coef_single_array = np.empty(shape=(time_axis_size),
                                             dtype=np.complex64)

        wavelet_fft_array, convolution_size_array, convolution_size_pow2 = self.compute_wavelet_ffts(
        )
        num_wavelets = wavelet_fft_array.shape[0]

        wavelet_start = time.time()

        for idx_tuple, signal in data_iterator:

            signal_fft = fft(signal, convolution_size_pow2)

            for w in range(num_wavelets):
                signal_wavelet_conv = ifft(wavelet_fft_array[w] * signal_fft)

                # computing trim indices for the wavelet_coeff array
                start_offset = (convolution_size_array[w] - time_axis_size) // 2
                end_offset = start_offset + time_axis_size

                wavelet_coef_single_array[:] = signal_wavelet_conv[
                    start_offset:end_offset]

                out_idx_tuple = idx_tuple + (w, )

                pow_array_single, phase_array_single = self.compute_power_and_phase_fcn(
                    wavelet_coef_single_array)

                self.store(out_idx_tuple, wavelet_pow_array, pow_array_single)
                self.store(out_idx_tuple, wavelet_phase_array,
                           phase_array_single)

        if self.verbose:
            print('total time wavelet loop: ', time.time() - wavelet_start)

        return self.build_output_arrays(wavelet_pow_array, wavelet_phase_array,
                                        time_axis)
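
# A minimal usage sketch (assumes `base_eegs` is a TimeSeriesX such as the
# output of EEGReader above; the frequency grid is illustrative):
wf = MorletWaveletFilter(time_series=base_eegs,
                         freqs=np.logspace(np.log10(3), np.log10(180), 8),
                         output='power')
pow_wavelet, phase_wavelet = wf.filter()  # phase_wavelet is None when output='power'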
Example #16
class TalReader(PropertiedObject, BaseReader):
    '''
    Reader that reads tal structs Matlab file and converts it to numpy recarray
    '''
    _descriptors = [
        TypeValTuple('filename', str, ''),
        TypeValTuple('struct_name', str, 'bpTalStruct'),
    ]

    def __init__(self, **kwds):
        '''
        Constructor:

        :param kwds:allowed values are:
        -------------------------------------
        :param filename {str} -  path to tal file
        :param struct_name {str} -  name of the matlab struct to load
        :return: None
        '''

        self.init_attrs(kwds)
        self.bipolar_channels = None

        self.tal_struct_array = None

    def get_bipolar_pairs(self):
        '''

        :return: numpy recarray where each record has two fields 'ch0' and 'ch1' storing channel labels.
        '''
        if self.bipolar_channels is None:
            if self.tal_struct_array is None:
                self.read()
            self.initialize_bipolar_pairs()

        return self.bipolar_channels

    def get_monopolar_channels(self):
        '''

        :return: numpy array of monopolar channel labels
        '''
        bipolar_array = self.get_bipolar_pairs()
        monopolar_set = set(
            list(bipolar_array['ch0']) + list(bipolar_array['ch1']))
        return np.array(sorted(list(monopolar_set)))

    def initialize_bipolar_pairs(self):
        # initialize bipolar pairs
        self.bipolar_channels = np.recarray(shape=(len(self.tal_struct_array)),
                                            dtype=[('ch0', '|S3'),
                                                   ('ch1', '|S3')])

        channel_record_array = self.tal_struct_array['channel']
        for i, channel_array in enumerate(channel_record_array):
            self.bipolar_channels[i] = tuple(
                map(lambda x: str(x).zfill(3), channel_array))

    def read(self):
        '''

        :return:np.recarray representing tal struct array (originally defined in Matlab file)
        '''

        from ptsa.data.MatlabIO import read_single_matlab_matrix_as_numpy_structured_array

        struct_names = ['bpTalStruct', 'subjTalEvents']
        # struct_names = ['bpTalStruct']
        if self.struct_name not in struct_names:
            self.tal_struct_array = read_single_matlab_matrix_as_numpy_structured_array(
                self.filename, self.struct_name, verbose=False)

            if self.tal_struct_array is not None:
                return self.tal_struct_array
            else:
                raise AttributeError(
                    'Could not read tal struct data for the specified struct_name='
                    + self.struct_name)

        else:

            for sn in struct_names:
                self.tal_struct_array = read_single_matlab_matrix_as_numpy_structured_array(
                    self.filename, sn, verbose=False)
                if self.tal_struct_array is not None:
                    return self.tal_struct_array

        raise AttributeError(
            'Could not read tal struct data. Try specifying struct_name argument :'
            '\nTalReader(filename=e_path, struct_name=<name_of_struc_to_read>)'
        )
Example #17
class ResampleFilter(PropertiedObject, BaseFilter):
    """Upsample or downsample a time series to a new sample rate.

    Keyword Arguments
    -----------------
    time_series
        TimeSeriesX object
    resamplerate: float
        new sampling frequency
    time_axis_index: int
        index of the time axis
    round_to_original_timepoints: bool
        Flag indicating if timepoints from original time axis
        should be reused after proper rounding. Defaults to False

"""

    _descriptors = [
        TypeValTuple('time_series', TimeSeriesX,
                     TimeSeriesX([0.0], dict(samplerate=1), dims=['time'])),
        TypeValTuple('resamplerate', float, -1.0),
        TypeValTuple('time_axis_index', int, -1),
        TypeValTuple('round_to_original_timepoints', bool, False),
    ]

    def ___syntax_helper(self):
        self.time_series = None
        self.resamplerate = None
        self.time_axis_index = None
        self.round_to_original_timepoints = None

    def __init__(self, **kwds):
        self.window = None
        # self.time_series = None
        self.init_attrs(kwds)

    def filter(self):
        """resamples time series

        Returns
        -------
        resampled: TimeSeriesX
            resampled time series with sampling frequency set to resamplerate

        """
        samplerate = float(self.time_series['samplerate'])

        time_axis_length = np.squeeze(self.time_series.coords['time'].shape)
        new_length = int(
            np.round(time_axis_length * self.resamplerate / samplerate))

        if self.time_axis_index < 0:
            self.time_axis_index = self.time_series.get_axis_num('time')

        time_axis = self.time_series.coords[self.time_series.dims[
            self.time_axis_index]]

        try:
            time_axis_data = time_axis.data[
                'time']  # time axis can be recarray with one of the arrays being time
        except (KeyError, IndexError) as excp:
            # if we get here then most likely time axis is ndarray of floats
            time_axis_data = time_axis.data

        time_idx_array = np.arange(len(time_axis))

        if self.round_to_original_timepoints:
            filtered_array, new_time_idx_array = resample(
                self.time_series.data,
                new_length,
                t=time_idx_array,
                axis=self.time_axis_index,
                window=self.window)

            # print new_time_axis

            new_time_idx_array = np.rint(new_time_idx_array).astype(np.int)

            new_time_axis = time_axis[new_time_idx_array]

        else:
            filtered_array, new_time_axis = resample(self.time_series.data,
                                                     new_length,
                                                     t=time_axis_data,
                                                     axis=self.time_axis_index,
                                                     window=self.window)

        coords = {}
        for i, dim_name in enumerate(self.time_series.dims):
            if i != self.time_axis_index:
                coords[dim_name] = self.time_series.coords[dim_name].copy()
            else:
                coords[dim_name] = new_time_axis
        coords['samplerate'] = self.resamplerate

        filtered_time_series = TimeSeriesX(filtered_array,
                                           coords=coords,
                                           dims=self.time_series.dims)
        return filtered_time_series
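
# A minimal usage sketch (downsamples a hypothetical TimeSeriesX `base_eegs`
# to 100 Hz along its time axis):
resample_filter = ResampleFilter(time_series=base_eegs, resamplerate=100.0)
base_eegs_resampled = resample_filter.filter()  # samplerate coord becomes 100.0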
Example #18
class BaseEventReader(PropertiedObject, BaseReader):
    '''
    Reader class that reads event file and returns them as np.recarray
    '''
    _descriptors = [
        TypeValTuple('filename', str, ''),
        TypeValTuple('eliminate_events_with_no_eeg', bool, True),
        TypeValTuple('eliminate_nans', bool, True),
        TypeValTuple('use_reref_eeg', bool, False),
    ]

    def __init__(self, **kwds):
        '''
        Constructor:

        :param kwds:allowed values are:
        -------------------------------------
        :param filename {str} -  path to event file
        :param eliminate_events_with_no_eeg {bool} - flag to automatically remove events with no eegfile (default True)
        :param eliminate_nans {bool} - flag to automatically replace nans in the event structs with -999 (default True)
        :param use_reref_eeg {bool} -  flag that changes eegfiles to point to reref eegs. Default is False, meaning the eegs
        read are the non-reref ones

        :return: None
        '''
        self.init_attrs(kwds)

    def correct_eegfile_field(self, events):
        '''
        Replaces 'eeg.reref' with 'eeg.noreref' in eegfile path
        :param events: np.recarray representing events. One of the fields of this array should be eegfile
        :return: events with modified eegfile paths
        '''
        data_dir_bad = r'/data.*/' + events[0].subject + r'/eeg'
        data_dir_good = r'/data/eeg/' + events[0].subject + r'/eeg'
        for ev in events:
            ev.eegfile = ev.eegfile.replace('eeg.reref', 'eeg.noreref')
            ev.eegfile = re.sub(data_dir_bad, data_dir_good, ev.eegfile)
        return events

    def read(self):
        '''
        Reads Matlab event file and returns the corresponding np.recarray. Path to the eegfile is changed
        w.r.t. the original Matlab code to account for the following:
        1. /data dir of the database might have been mounted under a different mount point e.g. /Users/m/data
        2. use_reref_eeg is set to True in which case we replace 'eeg.reref' with 'eeg.noreref' in the eegfile path

        :return: np.recarray representing events
        '''
        from ptsa.data.MatlabIO import read_single_matlab_matrix_as_numpy_structured_array

        # extract matlab matrix (called 'events') as numpy structured array
        struct_array = read_single_matlab_matrix_as_numpy_structured_array(
            self.filename, 'events')

        evs = struct_array

        if self.eliminate_events_with_no_eeg:

            # eliminating events that have no eeg file
            indicator = np.empty(len(evs), dtype=bool)
            indicator[:] = False

            for i, ev in enumerate(evs):
                # MAKE THIS CHECK STRONGER
                indicator[i] = (len(str(evs[i].eegfile)) > 3)
                # indicator[i] = (type(evs[i].eegfile).__name__.startswith('unicode')) & (len(str(evs[i].eegfile)) > 3)

            evs = evs[indicator]

        # determining data_dir_prefix in case rhino /data filesystem was mounted under different root
        data_dir_prefix = self.find_data_dir_prefix()
        for i, ev in enumerate(evs):
            ev.eegfile = join(data_dir_prefix,
                              str(pathlib.Path(str(ev.eegfile)).parts[1:]))

        if not self.use_reref_eeg:
            evs = self.correct_eegfile_field(evs)

        if self.eliminate_nans:
            evs = self.replace_nans(evs)

        return evs

    def replace_nans(self, evs, replacement_val=-999):

        for descr in evs.dtype.descr:
            field_name = descr[0]

            try:
                nan_selector = np.isnan(evs[field_name])
                evs[field_name][nan_selector] = replacement_val
            except TypeError:
                pass
        return evs

    def find_data_dir_prefix(self):
        '''
        determining dir_prefix

        data on rhino database is mounted as /data
        copying the rhino /data structure to another directory will cause all files in data to have a new prefix
        example:
        self._filename='/Users/m/data/events/R1060M_events.mat'
        prefix is '/Users/m'
        we use find_dir_prefix to determine the prefix based on common_root in the path with and without the prefix

        :return: data directory prefix
        '''

        common_root = 'data/events'
        prefix = find_dir_prefix(path_with_prefix=self._filename,
                                 common_root=common_root)
        if not prefix:
            raise RuntimeError(
                'Could not determine prefix from: %s using common_root: %s' %
                (self._filename, common_root))

        return prefix
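
# A minimal usage sketch (hypothetical Matlab events file; the 'type' field
# is an assumption about the events structure, not guaranteed by this reader):
base_e_reader = BaseEventReader(filename='/data/events/RAM_FR1/R1111M_events.mat',
                                eliminate_events_with_no_eeg=True)
base_events = base_e_reader.read()
word_events = base_events[base_events.type == 'WORD']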
Example #19
class ParamsReader(PropertiedObject, BaseReader):
    '''
    Reader for parameter file (e.g. params.txt)
    '''
    _descriptors = [
        TypeValTuple('filename', str, ''),
        TypeValTuple('dataroot', str, ''),
    ]

    def __init__(self, **kwds):
        '''
        Constructor
        :param kwds:allowed values are:
        -------------------------------------
        :param filename {str} -  path to params file
        :param dataroot {str} -  core name of the eegfiles

        :return: None
        '''
        self.init_attrs(kwds)

        if self.filename:
            if not isfile(self.filename):
                raise IOError('Could not open params file: %s' % self.filename)

        elif self.dataroot:
            self.filename = self.locate_params_file(dataroot=self.dataroot)
        else:
            raise IOError(
                'Could not find params file using dataroot: %s or using direct path:%s'
                % (self.dataroot, self.filename))

        Converter = collections.namedtuple('Converter', ['convert', 'name'])
        self.param_to_convert_fcn = {
            'samplerate':
            Converter(convert=float, name='samplerate'),
            'gain':
            Converter(convert=float, name='gain'),
            'format':
            Converter(convert=lambda s: s.replace("'", "").replace('"', ''),
                      name='format'),
            'dataformat':
            Converter(convert=lambda s: s.replace("'", "").replace('"', ''),
                      name='format')
        }

    def locate_params_file(self, dataroot):
        """
        Identifies exact path to param file.
        :param dataroot: {str} eeg core file name
        :return: {str}
        """

        for param_file in (abspath(dataroot + '.params'),
                           abspath(join(dirname(dataroot), 'params.txt'))):

            if isfile(param_file):
                return param_file

        raise IOError('No params file found in ' + dirname(dataroot) +
                      '. Params files must be in the same directory ' +
                      'as the EEG data and must be named .params ' +
                      'or params.txt.')

    def read(self):
        """
        Parses param file
        :return: {dict} dictionary with param file content
        """
        params = {}
        param_file = self.filename

        # we have a file, so open and process it
        with open(param_file, 'r') as param_fp:
            for line in param_fp:
                # get the columns by splitting
                param_name, str_to_convert = line.strip().split()[:2]
                try:
                    convert_tuple = self.param_to_convert_fcn[param_name]
                    params[convert_tuple.name] = convert_tuple.convert(
                        str_to_convert)
                except KeyError:
                    pass
        if not set(params.keys()).issuperset(set(['gain', 'samplerate'])):
            raise ValueError(
                'Params file must contain samplerate and gain!\n' +
                'The following fields were supplied:\n' + str(params.keys()))

        return params
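
# A minimal usage sketch (hypothetical dataroot; the reader looks for
# '<dataroot>.params' or a 'params.txt' file next to the EEG data):
p_reader = ParamsReader(dataroot='/data/eeg/R1111M/eeg.noreref/R1111M_01Jan16_1200')
params = p_reader.read()
samplerate = params['samplerate']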
Example #20
class BaseEventReader(PropertiedObject, BaseReader):
    """Reader class that reads event file and returns them as np.recarray.

    Keyword arguments
    -----------------
    filename : str
        path to event file
    eliminate_events_with_no_eeg : bool
        flag to automatically remove events with no eegfile (default True)
    eliminate_nans : bool
        flag to automatically replace nans in the event structs with -999 (default True)
    use_reref_eeg : bool
        flag that changes eegfiles to point to reref eegs. Default is False,
        meaning the eegs read are the non-reref ones
    normalize_eeg_path : bool
        flag that determines if 'data1', 'data2', etc... in eeg path will
        get converted to 'data'. The flag is True by default meaning all
        'data1', 'data2', etc... are converted to 'data'
    common_root : str
        partial path to root events folder e.g. if you events are placed in
        /data/events/RAM_FR1 the path should be 'data/events'. If your
        events are placed in the '/data/scalp_events/catFR' the common root
        should be 'data/scalp_events'. Note that you do not include opening
        '/' in the common_root

    """
    _descriptors = [
        TypeValTuple('filename', six.string_types, ''),
        TypeValTuple('eliminate_events_with_no_eeg', bool, True),
        TypeValTuple('eliminate_nans', bool, True),
        TypeValTuple('use_reref_eeg', bool, False),
        TypeValTuple('normalize_eeg_path', bool, True),
        TypeValTuple('common_root', six.string_types, 'data/events')
    ]

    def __init__(self, **kwds):
        self.init_attrs(kwds)
        self._alter_eeg_path_flag = not self.use_reref_eeg

    @property
    def alter_eeg_path_flag(self):
        return self._alter_eeg_path_flag

    @alter_eeg_path_flag.setter
    def alter_eeg_path_flag(self, val):
        self._alter_eeg_path_flag = val
        self.use_reref_eeg = not self._alter_eeg_path_flag

    def normalize_paths(self, events):
        """
        Replaces data1, data2 etc... in the eegfile column of the events with data
        :param events: np.recarray representing events. One of the fields of this array should be eegfile
        :return: events with normalized eegfile paths
        """
        subject = events[0].subject
        if sys.platform.startswith('win'):
            data_dir_bad = r'\\data.*\\' + subject + r'\\eeg'
            data_dir_good = r'\\data\\eeg\\' + subject + r'\\eeg'
        else:
            data_dir_bad = r'/data.*/' + subject + r'/eeg'
            data_dir_good = r'/data/eeg/' + subject + r'/eeg'

        for ev in events:
            # ev.eegfile = ev.eegfile.replace('eeg.reref', 'eeg.noreref')
            ev.eegfile = re.sub(data_dir_bad, data_dir_good, ev.eegfile)
        return events

    def modify_eeg_path(self, events):
        """
        Replaces 'eeg.reref' with 'eeg.noreref' in eegfile path
        :param events: np.recarray representing events. One of the fields of this array should be eegfile
        :return: events with modified eegfile paths
        """

        for ev in events:
            ev.eegfile = ev.eegfile.replace('eeg.reref', 'eeg.noreref')
        return events

    def read(self):
        if os.path.splitext(self.filename)[-1] == '.json':
            return self.read_json()
        else:
            return self.read_matlab()

    def as_dataframe(self):
        """Read events and return as a :class:`pd.DataFrame`."""
        events = self.read()
        return pd.DataFrame(events)

    def check_reader_settings_for_json_read(self):

        if self.use_reref_eeg:
            raise NotImplementedError('Reref from JSON not implemented')

    def read_json(self):

        self.check_reader_settings_for_json_read()

        evs = self.from_json(self.filename)

        if self.eliminate_events_with_no_eeg:
            # eliminating events that have no eeg file
            indicator = np.empty(len(evs), dtype=bool)
            indicator[:] = False

            for i, ev in enumerate(evs):
                # MAKE THIS CHECK STRONGER
                indicator[i] = (len(str(evs[i].eegfile)) > 3)

            evs = evs[indicator]

        if 'eegfile' in evs.dtype.names:
            eeg_dir = os.path.join(os.path.dirname(self.filename), '..', '..',
                                   'ephys', 'current_processed', 'noreref')
            eeg_dir = os.path.abspath(eeg_dir)
            for ev in evs:
                ev.eegfile = os.path.join(eeg_dir, ev.eegfile)

        return evs

    def read_matlab(self):
        """
        Reads Matlab event file and returns the corresponding np.recarray. Path to the eegfile is changed
        w.r.t. the original Matlab code to account for the following:
        1. /data dir of the database might have been mounted under a different mount point e.g. /Users/m/data
        2. use_reref_eeg is set to True in which case we replace 'eeg.reref' with 'eeg.noreref' in the eegfile path

        :return: np.recarray representing events
        """
        # extract matlab matrix (called 'events') as numpy structured array
        struct_array = read_single_matlab_matrix_as_numpy_structured_array(
            self.filename, 'events')

        evs = struct_array

        if 'eegfile' in evs.dtype.names:
            if self.eliminate_events_with_no_eeg:

                # eliminating events that have no eeg file
                indicator = np.empty(len(evs), dtype=bool)
                indicator[:] = False

                for i, ev in enumerate(evs):
                    # MAKE THIS CHECK STRONGER
                    indicator[i] = (len(str(evs[i].eegfile)) > 3)
                    # indicator[i] = (type(evs[i].eegfile).__name__.startswith('unicode')) & (len(str(evs[i].eegfile)) > 3)

                evs = evs[indicator]

            # determining data_dir_prefix in case rhino /data filesystem was mounted under different root
            if self.normalize_eeg_path:
                data_dir_prefix = self.find_data_dir_prefix()
                for i, ev in enumerate(evs):
                    ev.eegfile = join(
                        data_dir_prefix,
                        str(pathlib.Path(str(ev.eegfile)).parts[1:]))

                evs = self.normalize_paths(evs)

            # if not self.use_reref_eeg:
            if self._alter_eeg_path_flag:
                evs = self.modify_eeg_path(evs)

        if self.eliminate_nans:
            evs = self.replace_nans(evs)

        return evs

    def replace_nans(self, evs, replacement_val=-999):

        for descr in evs.dtype.descr:
            field_name = descr[0]

            try:
                nan_selector = np.isnan(evs[field_name])
                evs[field_name][nan_selector] = replacement_val
            except TypeError:
                pass
        return evs

    def find_data_dir_prefix(self):
        """
        determining dir_prefix

        data on rhino database is mounted as /data
        copying the rhino /data structure to another directory will cause all files in data to have a new prefix
        example:
        self._filename='/Users/m/data/events/R1060M_events.mat'
        prefix is '/Users/m'
        we use find_dir_prefix to determine the prefix based on common_root in the path with and without the prefix

        :return: data directory prefix
        """

        prefix = find_dir_prefix(path_with_prefix=self._filename,
                                 common_root=self.common_root)
        if not prefix:
            raise RuntimeError(
                'Could not determine prefix from: %s using common_root: %s' %
                (self._filename, self.common_root))

        return prefix

    ### TODO: CLEAN UP, COMMENT

    @classmethod
    def get_element_dtype(cls, element):
        if isinstance(element, dict):
            return cls.mkdtype(element)
        elif isinstance(element, int):
            return 'int64'
        elif isinstance(element, six.string_types):
            return 'S256'
        elif isinstance(element, bool):
            return 'b'
        elif isinstance(element, float):
            return 'float64'
        elif isinstance(element, list):
            return cls.get_element_dtype(element[0])
        else:
            raise Exception('Could not convert type %s' % type(element))

    @classmethod
    def mkdtype(cls, d):
        if isinstance(d, list):
            dtype = cls.mkdtype(d[0])
            return dtype
        dtype = []

        for k, v in list(d.items()):
            dtype.append((str(k), cls.get_element_dtype(v)))

        return np.dtype(dtype)

    @classmethod
    def from_json(cls, json_filename):
        d = json.load(open(json_filename))
        return cls.from_dict(d)

    @classmethod
    def from_dict(cls, d):
        if not isinstance(d, list):
            d = [d]

        list_names = []

        for k, v in list(d[0].items()):
            if isinstance(v, list):
                list_names.append(k)

        list_info = defaultdict(lambda *_: {'len': 0, 'dtype': None})

        for entry in d:
            for k in list_names:
                list_info[k]['len'] = max(list_info[k]['len'], len(entry[k]))
                if not list_info[k]['dtype'] and len(entry[k]) > 0:
                    if isinstance(entry[k][0], dict):
                        list_info[k]['dtype'] = cls.mkdtype(entry[k][0])
                    else:
                        list_info[k]['dtype'] = cls.get_element_dtype(entry[k])

        dtypes = []
        for k, v in list(d[0].items()):
            if k not in list_info:
                dtypes.append((str(k), cls.get_element_dtype(v)))
            else:
                dtypes.append(
                    (str(k), list_info[k]['dtype'], list_info[k]['len']))

        if dtypes:
            arr = np.zeros(len(d), dtypes).view(np.recarray)
            cls.copy_values(d, arr, list_info)
        else:
            arr = np.array([])
        return arr.view(np.recarray)

    @classmethod
    def copy_values(cls, dict_list, rec_arr, list_info=None):
        if len(dict_list) == 0:
            return

        dict_fields = {}
        for k, v in list(dict_list[0].items()):
            if isinstance(v, dict):
                dict_fields[k] = [inner_dict[k] for inner_dict in dict_list]

        for i, sub_dict in enumerate(dict_list):
            for k, v in list(sub_dict.items()):
                if k in dict_fields or list_info and k in list_info:
                    continue

                if isinstance(v, dict):
                    cls.copy_values([v], rec_arr[i][k])
                elif isinstance(v, six.string_types):
                    rec_arr[i][k] = cls.strip_accents(v)
                else:
                    rec_arr[i][k] = v

        for i, sub_dict in enumerate(dict_list):
            for k, v in list(sub_dict.items()):
                if list_info and k in list_info:
                    arr = np.zeros(list_info[k]['len'], list_info[k]['dtype'])
                    if len(v) > 0:
                        if isinstance(v[0], dict):
                            cls.copy_values(v, arr)
                        else:
                            for j, element in enumerate(v):
                                arr[j] = element

                    rec_arr[i][k] = arr.view(np.recarray)

        for k, v in list(dict_fields.items()):
            cls.copy_values(v, rec_arr[k])

    @classmethod
    def strip_accents(cls, s):
        try:
            return str(''.join(
                c for c in unicodedata.normalize('NFD', six.text_type(s))
                if unicodedata.category(c) != 'Mn'))
        except UnicodeError:  # If accents can't be converted, just remove them
            return str(re.sub(r'[^A-Za-z0-9 -_.]', '', s))
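
# A minimal sketch of the JSON path (hypothetical filename; use_reref_eeg must
# stay False because check_reader_settings_for_json_read raises otherwise):
json_reader = BaseEventReader(filename='task_events.json')
events = json_reader.read()              # dispatches to read_json() by extension
events_df = json_reader.as_dataframe()   # same events as a pd.DataFrame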
Example #21
class MorletWaveletFilterCpp(PropertiedObject, BaseFilter):
    _descriptors = [
        TypeValTuple('freqs', np.ndarray, np.array([], dtype=np.float)),
        TypeValTuple('width', int, 5),
        TypeValTuple('output', str, 'power'),
        TypeValTuple('frequency_dim_pos', int, 0),
        TypeValTuple('cpus', int, 1),
        # NOTE in this implementation the default position of frequency is -2
        TypeValTuple('verbose', bool, True),
    ]

    def __init__(self, time_series, **kwds):

        self.window = None
        self.time_series = time_series
        self.init_attrs(kwds)

    def filter(self):

        time_axis = self.time_series['time']

        time_axis_size = time_axis.shape[0]
        samplerate = float(self.time_series['samplerate'])

        wavelet_dims = self.time_series.shape[:-1] + (self.freqs.shape[0], )
        if self.verbose:
            print('wavelet dims:', wavelet_dims)

        powers_reshaped = np.array([[]], dtype=np.float64)
        phases_reshaped = np.array([[]], dtype=np.float64)
        wavelets_complex_reshaped = np.array([[]], dtype=np.complex128)

        if self.output == 'power':
            powers_reshaped = np.empty(shape=(np.prod(wavelet_dims),
                                              self.time_series.shape[-1]),
                                       dtype=np.float64)
        if self.output == 'phase':
            phases_reshaped = np.empty(shape=(np.prod(wavelet_dims),
                                              self.time_series.shape[-1]),
                                       dtype=np.float64)
        if self.output == 'both':
            powers_reshaped = np.empty(shape=(np.prod(wavelet_dims),
                                              self.time_series.shape[-1]),
                                       dtype=np.float64)
            phases_reshaped = np.empty(shape=(np.prod(wavelet_dims),
                                              self.time_series.shape[-1]),
                                       dtype=np.float64)
        if self.output == 'complex':
            wavelets_complex_reshaped = np.empty(
                shape=(np.prod(wavelet_dims), self.time_series.shape[-1]),
                dtype=np.complex128)

        mt = MorletWaveletTransformMP(self.cpus)

        time_series_reshaped = np.ascontiguousarray(
            self.time_series.data.reshape(np.prod(self.time_series.shape[:-1]),
                                          self.time_series.shape[-1]),
            self.time_series.data.dtype)
        if self.output == 'power':
            mt.set_output_type(morlet.POWER)
        if self.output == 'phase':
            mt.set_output_type(morlet.PHASE)
        if self.output == 'both':
            mt.set_output_type(morlet.BOTH)
        if self.output == 'complex':
            mt.set_output_type(morlet.COMPLEX)

        mt.set_signal_array(time_series_reshaped)
        mt.set_wavelet_pow_array(powers_reshaped)
        mt.set_wavelet_phase_array(phases_reshaped)
        mt.set_wavelet_complex_array(wavelets_complex_reshaped)

        mt.initialize_signal_props(float(self.time_series['samplerate']))
        mt.initialize_wavelet_props(self.width, self.freqs)
        mt.prepare_run()

        s = time.time()
        mt.compute_wavelets_threads()

        powers_final = None
        phases_final = None
        wavelet_complex_final = None

        if self.output == 'power':
            powers_final = powers_reshaped.reshape(
                wavelet_dims + (self.time_series.shape[-1], ))
        if self.output == 'phase':
            phases_final = phases_reshaped.reshape(
                wavelet_dims + (self.time_series.shape[-1], ))
        if self.output == 'both':
            powers_final = powers_reshaped.reshape(
                wavelet_dims + (self.time_series.shape[-1], ))
            phases_final = phases_reshaped.reshape(
                wavelet_dims + (self.time_series.shape[-1], ))
        if self.output == 'complex':
            wavelet_complex_final = wavelets_complex_reshaped.reshape(
                wavelet_dims + (self.time_series.shape[-1], ))

        coords = {k: v for k, v in list(self.time_series.coords.items())}
        coords['frequency'] = self.freqs

        powers_ts = None
        phases_ts = None
        wavelet_complex_ts = None

        if powers_final is not None:
            powers_ts = TimeSeriesX(powers_final,
                                    dims=self.time_series.dims[:-1] + (
                                        'frequency',
                                        self.time_series.dims[-1],
                                    ),
                                    coords=coords)
            final_dims = (powers_ts.dims[-2], ) + powers_ts.dims[:-2] + (
                powers_ts.dims[-1], )

            powers_ts = powers_ts.transpose(*final_dims)

        if phases_final is not None:
            phases_ts = TimeSeriesX(phases_final,
                                    dims=self.time_series.dims[:-1] + (
                                        'frequency',
                                        self.time_series.dims[-1],
                                    ),
                                    coords=coords)

            final_dims = (phases_ts.dims[-2], ) + phases_ts.dims[:-2] + (
                phases_ts.dims[-1], )

            phases_ts = phases_ts.transpose(*final_dims)

        if wavelet_complex_final is not None:
            wavelet_complex_ts = TimeSeriesX(wavelet_complex_final,
                                             dims=self.time_series.dims[:-1] +
                                             (
                                                 'frequency',
                                                 self.time_series.dims[-1],
                                             ),
                                             coords=coords)

            final_dims = (wavelet_complex_ts.dims[-2],
                          ) + wavelet_complex_ts.dims[:-2] + (
                              wavelet_complex_ts.dims[-1], )

            wavelet_complex_ts = wavelet_complex_ts.transpose(*final_dims)

        if self.verbose:
            print('CPP total time wavelet loop: ', time.time() - s)

        if wavelet_complex_ts is not None:
            return wavelet_complex_ts, None
        else:
            return powers_ts, phases_ts
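
A minimal usage sketch, assuming `ts` is a TimeSeriesX with a trailing 'time' dimension and a 'samplerate' coordinate (e.g. loaded via EEGReader), and that the compiled `morlet` extension is importable:

import numpy as np

# assumed inputs: `ts` (TimeSeriesX) and the MorletWaveletFilterCpp class above
freqs = np.logspace(np.log10(3.0), np.log10(180.0), 8)
wf = MorletWaveletFilterCpp(ts, freqs=freqs, width=5,
                            output='power', cpus=4)
pow_ts, phase_ts = wf.filter()   # phase_ts is None for output='power'
print(pow_ts.dims)               # 'frequency' is transposed to the front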
Example #22
class ParamsReader(PropertiedObject, BaseReader):
    """
    Reader for parameter file (e.g. params.txt)
    """
    _descriptors = [
        TypeValTuple('filename', six.string_types, ''),
        TypeValTuple('dataroot', six.string_types, ''),
    ]

    def __init__(self, **kwds):
        """
        Constructor
        :param kwds:allowed values are:
        -------------------------------------
        :param filename {str} -  path t params file
        :param dataroot {str} -  core name of the eegfiles

        :return: None
        """
        self.init_attrs(kwds)

        if self.filename:
            if not isfile(self.filename):
                raise IOError('Could not open params file: %s' % self.filename)

        elif self.dataroot:
            self.filename = self.locate_params_file(dataroot=self.dataroot)
        else:
            raise IOError('Could not find params file using dataroot: %s '
                          'or using direct path: %s' %
                          (self.dataroot, self.filename))

        if splitext(self.filename)[-1] == '.txt':
            Converter = collections.namedtuple('Converter', ['convert', 'name'])
            self.param_to_convert_fcn = {
                'samplerate': Converter(convert=float, name='samplerate'),
                'gain': Converter(convert=float, name='gain'),
                'format': Converter(convert=lambda s: s.replace("'", "").replace('"', ''), name='format'),
                'dataformat': Converter(convert=lambda s: s.replace("'", "").replace('"', ''), name='format')
            }

    def locate_params_file(self, dataroot):
        """
        Identifies exact path to param file.
        :param dataroot: {str} eeg core file name
        :return: {str}
        """

        for param_file in (abspath(dataroot + '.params'),
                           abspath(join(dirname(dataroot), 'params.txt'))):

            if isfile(param_file):
                return param_file

        param_file = join(dirname(dataroot), '..', 'sources.json')
        if isfile(param_file):
            return param_file

        raise IOError('No params file found in ' + str(dataroot) +
                      '. Params files must be in the same directory ' +
                      'as the EEG data and must be named .params ' +
                      'or params.txt, or in the directory above and '
                      'named sources.json')

    def read(self):
        if splitext(self.filename)[-1] == '.txt':
            return self.read_txt()
        else:
            return self.read_json()

    def read_json(self):
        with open(self.filename) as f:
            json_params = json.load(f)[basename(self.dataroot)]
        params = {}
        params['samplerate'] = json_params['sample_rate']
        params['gain'] = 1
        params['format'] = json_params['data_format']
        params['dataformat'] = json_params['data_format']
        return params

    def read_txt(self):
        """
        Parses param file
        :return: {dict} dictionary with param file content
        """
        params = {}
        param_file = self.filename

        # we have a file, so open and process it
        with open(param_file, 'r') as f:
            for line in f:

                stripped_line_list = line.strip().split()
                if len(stripped_line_list) < 2:
                    continue

                # get the columns by splitting
                param_name, str_to_convert = stripped_line_list[:2]
                try:
                    convert_tuple = self.param_to_convert_fcn[param_name]
                    params[convert_tuple.name] = convert_tuple.convert(str_to_convert)
                except KeyError:
                    pass

        if 'gain' not in params:
            params['gain'] = 1.0
            warnings.warn('Did not find "gain" in the params.txt file. Assuming gain=1.0', RuntimeWarning)

        if not {'gain', 'samplerate'}.issubset(params):
            raise ValueError(
                'Params file must contain samplerate and gain!\n' +
                'The following fields were supplied:\n' + str(list(params.keys())))

        return params
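
A minimal usage sketch; the paths below are illustrative, not real files:

# either point directly at a params file ...
reader = ParamsReader(filename='/some/path/params.txt')
# ... or let the reader locate it next to the eeg files given a dataroot
reader = ParamsReader(dataroot='/some/path/eeg.reref/subj_session0')
params = reader.read()
print(params['samplerate'], params['gain'])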
Example #23
class MorletWaveletFilter(PropertiedObject):
    _descriptors = [
        TypeValTuple('freqs', np.ndarray, np.array([], dtype=np.float64)),
        TypeValTuple('time_axis_index', int, -1),
        TypeValTuple(
            'bipolar_pairs', np.recarray,
            np.recarray((0, ), dtype=[('ch0', '|S3'), ('ch1', '|S3')])),
        TypeValTuple('resamplerate', float, -1),
        TypeValTuple('output', str, '')
    ]

    def __init__(self, time_series, **kwds):

        self.window = None
        self.time_series = time_series
        self.init_attrs(kwds)

        self.compute_power_and_phase_fcn = None
        if self.output == 'power':
            self.compute_power_and_phase_fcn = self.compute_power
        elif self.output == 'phase':
            self.compute_power_and_phase_fcn = self.compute_phase
        else:
            self.compute_power_and_phase_fcn = self.compute_power_and_phase

    def all_but_time_iterator(self, array):
        from itertools import product
        sizes_except_time = np.asarray(array.shape)[:-1]
        ranges = [range(size) for size in sizes_except_time]
        for cart_prod_idx_tuple in product(*ranges):
            yield cart_prod_idx_tuple, array[cart_prod_idx_tuple]

    def bipolar_iterator(self, array):
        from itertools import product
        sizes_except_time = np.asarray(array.shape)[:-1]

        # depending on the reader, channel axis may be a rec array or a simple array
        # we are interested in an array that has channel labels

        time_series_channel_axis = self.time_series['channels'].data
        try:
            time_series_channel_axis = time_series_channel_axis['name']
        except (KeyError, IndexError):
            pass

        ranges = [
            range(len(self.bipolar_pairs)),
            range(len(self.time_series['events']))
        ]
        for cart_prod_idx_tuple in product(*ranges):
            b, e = cart_prod_idx_tuple[0], cart_prod_idx_tuple[1]
            bp_pair = self.bipolar_pairs[b]

            ch0 = self.time_series.isel(
                channels=(time_series_channel_axis == bp_pair['ch0']),
                events=e).values
            ch1 = self.time_series.isel(
                channels=(time_series_channel_axis == bp_pair['ch1']),
                events=e).values

            yield cart_prod_idx_tuple, np.squeeze(ch0 - ch1)

    def resample_time_axis(self):
        from ptsa.data.filters.ResampleFilter import ResampleFilter

        rs_time_axis = None  # resampled time axis
        if self.resamplerate > 0:

            rs_time_filter = ResampleFilter(resamplerate=self.resamplerate)
            rs_time_filter.set_input(self.time_series[0, 0, :])
            time_series_resampled = rs_time_filter.filter()
            rs_time_axis = time_series_resampled['time']
        else:
            rs_time_axis = self.time_series['time']

        return rs_time_axis

    def allocate_output_arrays(self, time_axis_size):
        array_type = np.float32

        if len(self.bipolar_pairs):
            shape = (self.bipolar_pairs.shape[0],
                     self.time_series['events'].shape[0], self.freqs.shape[0],
                     time_axis_size)
        else:
            shape = self.time_series.shape[:-1] + (
                self.freqs.shape[0],
                time_axis_size,
            )

        if self.output == 'power':
            return np.empty(shape=shape, dtype=array_type), None
        elif self.output == 'phase':
            return None, np.empty(shape=shape, dtype=array_type)
        else:
            return np.empty(shape=shape,
                            dtype=array_type), np.empty(shape=shape,
                                                        dtype=array_type)

    def resample_power_and_phase(self, pow_array_single, phase_array_single,
                                 num_points):

        resampled_pow_array = None
        resampled_phase_array = None

        if self.resamplerate > 0.0:
            if pow_array_single is not None:
                resampled_pow_array = resample(pow_array_single,
                                               num=num_points)
            if phase_array_single is not None:
                resampled_phase_array = resample(phase_array_single,
                                                 num=num_points)

        else:
            resampled_pow_array = pow_array_single
            resampled_phase_array = phase_array_single

        return resampled_pow_array, resampled_phase_array

    def compute_power(self, wavelet_coef_array):
        return wavelet_coef_array.real**2 + wavelet_coef_array.imag**2, None

    def compute_phase(self, wavelet_coef_array):
        return None, np.angle(wavelet_coef_array)

    def compute_power_and_phase(self, wavelet_coef_array):
        return wavelet_coef_array.real**2 + wavelet_coef_array.imag**2, np.angle(
            wavelet_coef_array)

    def store_power_and_phase(self, idx_tuple, power_array, phase_array,
                              power_array_single, phase_array_single):

        if power_array_single is not None:
            power_array[idx_tuple] = power_array_single
        if phase_array_single is not None:
            phase_array[idx_tuple] = phase_array_single

    def get_data_iterator(self):
        if len(self.bipolar_pairs):
            data_iterator = self.bipolar_iterator(self.time_series)
        else:
            data_iterator = self.all_but_time_iterator(self.time_series)

        return data_iterator

    def construct_output_array(self, array, dims, coords):
        out_array = xray.DataArray(array, dims=dims, coords=coords)
        out_array.attrs['samplerate'] = self.time_series.attrs['samplerate']
        if self.resamplerate > 0.0:
            # after resampling, the effective sampling rate is the new one
            out_array.attrs['samplerate'] = self.resamplerate

        return out_array

    def build_output_arrays(self, wavelet_pow_array, wavelet_phase_array,
                            time_axis):
        wavelet_pow_array_xray = None
        wavelet_phase_array_xray = None

        if isinstance(self.time_series, xray.DataArray):

            if len(self.bipolar_pairs):
                dims = ['bipolar_pairs', 'events', 'frequency', 'time']
                coords = [
                    self.bipolar_pairs, self.time_series['events'], self.freqs,
                    time_axis
                ]
            else:
                dims = list(self.time_series.dims[:-1] + (
                    'frequency',
                    'time',
                ))
                coords = [
                    self.time_series.coords[dim_name]
                    for dim_name in self.time_series.dims[:-1]
                ]
                coords.append(self.freqs)
                coords.append(time_axis)

            if wavelet_pow_array is not None:
                wavelet_pow_array_xray = self.construct_output_array(
                    wavelet_pow_array, dims=dims, coords=coords)
            if wavelet_phase_array is not None:
                wavelet_phase_array_xray = self.construct_output_array(
                    wavelet_phase_array, dims=dims, coords=coords)

            return wavelet_pow_array_xray, wavelet_phase_array_xray

    def compute_wavelet_ffts(self):

        samplerate = self.time_series.attrs['samplerate']

        freqs = np.atleast_1d(self.freqs)

        wavelets = morlet_multi(freqs=freqs, widths=5, samplerates=samplerate)
        # ADD WARNING HERE FROM PHASE_MULTI

        num_wavelets = len(wavelets)

        # computing the length of the longest wavelet
        s_w = max(map(lambda wavelet: wavelet.shape[0], wavelets))
        # length of the time axis of the time series
        s_d = self.time_series['time'].shape[0]

        # determine the size based on the next power of 2
        convolution_size = s_w + s_d - 1
        convolution_size_pow2 = np.power(2, next_pow2(convolution_size))

        # preallocating arrays
        wavelet_fft_array = np.empty(shape=(num_wavelets,
                                            convolution_size_pow2),
                                     dtype=np.complex64)

        convolution_size_array = np.empty(shape=(num_wavelets,), dtype=np.int64)

        # computing wavelet ffts
        for i, wavelet in enumerate(wavelets):
            wavelet_fft_array[i] = fft(wavelet, convolution_size_pow2)
            convolution_size_array[i] = wavelet.shape[0] + s_d - 1

        return wavelet_fft_array, convolution_size_array, convolution_size_pow2

    def filter(self):

        data_iterator = self.get_data_iterator()

        time_axis = self.resample_time_axis()
        time_axis_size = time_axis.shape[0]

        wavelet_pow_array, wavelet_phase_array = self.allocate_output_arrays(
            time_axis_size=time_axis_size)

        # preallocating array
        wavelet_coef_single_array = np.empty(shape=(time_axis.shape[0],),
                                             dtype=np.complex64)

        wavelet_fft_array, convolution_size_array, convolution_size_pow2 = self.compute_wavelet_ffts(
        )
        num_wavelets = wavelet_fft_array.shape[0]

        wavelet_start = time.time()

        for idx_tuple, signal in data_iterator:

            signal_fft = fft(signal, convolution_size_pow2)

            for w in range(num_wavelets):

                signal_wavelet_conv = ifft(wavelet_fft_array[w] * signal_fft)

                # computing trim indices for the wavelet_coeff array
                start_offset = (convolution_size_array[w] - time_axis_size) // 2
                end_offset = start_offset + time_axis_size

                wavelet_coef_single_array[:] = signal_wavelet_conv[
                    start_offset:end_offset]

                out_idx_tuple = idx_tuple + (w, )

                pow_array_single, phase_array_single = self.compute_power_and_phase_fcn(
                    wavelet_coef_single_array)

                pow_array_single, phase_array_single = self.resample_power_and_phase(
                    pow_array_single, phase_array_single,
                    num_points=time_axis_size)

                self.store_power_and_phase(out_idx_tuple, wavelet_pow_array,
                                           wavelet_phase_array,
                                           pow_array_single,
                                           phase_array_single)

        print('total time wavelet loop: ', time.time() - wavelet_start)
        return self.build_output_arrays(wavelet_pow_array, wavelet_phase_array,
                                        time_axis)
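
A minimal usage sketch of the bipolar path, assuming `ts` is a TimeSeriesX with ('channels', 'events', 'time') dimensions and `pairs` is a recarray of ('ch0', 'ch1') channel labels; both are assumed inputs, not defined in this example:

import numpy as np

wf = MorletWaveletFilter(ts,
                         freqs=np.array([3.0, 5.0, 8.0]),
                         bipolar_pairs=pairs,
                         output='power')
pow_arr, phase_arr = wf.filter()   # phase_arr is None for output='power'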
Example #24
class MorletWaveletFilterSimple(PropertiedObject):
    _descriptors = [
        TypeValTuple('freqs', np.ndarray, np.array([], dtype=np.float64)),
        TypeValTuple('width', int, 5),
        TypeValTuple('output', str, ''),
        TypeValTuple('frequency_dim_pos', int, -2)
    ]

    def __init__(self, time_series, **kwds):

        self.window = None
        self.time_series = time_series
        self.init_attrs(kwds)

        self.compute_power_and_phase_fcn = None

        if self.output == 'power':
            self.compute_power_and_phase_fcn = self.compute_power
        elif self.output == 'phase':
            self.compute_power_and_phase_fcn = self.compute_phase
        else:
            self.compute_power_and_phase_fcn = self.compute_power_and_phase

    def all_but_time_iterator(self, array):
        from itertools import product
        sizes_except_time = np.asarray(array.shape)[:-1]
        ranges = [range(size) for size in sizes_except_time]
        for cart_prod_idx_tuple in product(*ranges):
            yield cart_prod_idx_tuple, array[cart_prod_idx_tuple]

    def resample_time_axis(self):
        # NOTE: unused by filter() below; also relies on a `resamplerate`
        # attribute that this class does not declare in its _descriptors
        from ptsa.data.filters.ResampleFilter import ResampleFilter

        rs_time_axis = None  # resampled time axis
        if self.resamplerate > 0:

            rs_time_filter = ResampleFilter(resamplerate=self.resamplerate)
            rs_time_filter.set_input(self.time_series[0, 0, :])
            time_series_resampled = rs_time_filter.filter()
            rs_time_axis = time_series_resampled['time']
        else:
            rs_time_axis = self.time_series['time']

        return rs_time_axis, self.time_series['time']

    def allocate_output_arrays(self, time_axis_size):
        array_type = np.float32
        shape = self.time_series.shape[:-1] + (
            self.freqs.shape[0],
            time_axis_size,
        )

        if self.output == 'power':
            return np.empty(shape=shape, dtype=array_type), None
        elif self.output == 'phase':
            return None, np.empty(shape=shape, dtype=array_type)
        else:
            return np.empty(shape=shape,
                            dtype=array_type), np.empty(shape=shape,
                                                        dtype=array_type)

    def compute_power(self, wavelet_coef_array):
        return np.abs(wavelet_coef_array) ** 2, None

    def compute_phase(self, wavelet_coef_array):
        return None, np.angle(wavelet_coef_array)

    def compute_power_and_phase(self, wavelet_coef_array):
        return wavelet_coef_array.real**2 + wavelet_coef_array.imag**2, np.angle(
            wavelet_coef_array)

    def store(self, idx_tuple, target_array, source_array):
        if source_array is not None:
            target_array[idx_tuple] = source_array

    def get_data_iterator(self):
        return self.all_but_time_iterator(self.time_series)

    def construct_output_array(self, array, dims, coords):
        out_array = xr.DataArray(array, dims=dims, coords=coords)
        out_array.attrs['samplerate'] = self.time_series.attrs['samplerate']
        return out_array

    def build_output_arrays(self, wavelet_pow_array, wavelet_phase_array,
                            time_axis):
        wavelet_pow_array_xray = None
        wavelet_phase_array_xray = None

        if isinstance(self.time_series, xr.DataArray):

            dims = list(self.time_series.dims[:-1] + (
                'frequency',
                'time',
            ))

            transposed_dims = []

            # getting frequency dim position as positive integer
            self.frequency_dim_pos = (len(dims) +
                                      self.frequency_dim_pos) % len(dims)
            orig_frequency_idx = dims.index('frequency')

            if self.frequency_dim_pos != orig_frequency_idx:
                transposed_dims = dims[:orig_frequency_idx] + dims[
                    orig_frequency_idx + 1:]
                transposed_dims.insert(self.frequency_dim_pos, 'frequency')

            coords = {
                dim_name: self.time_series.coords[dim_name]
                for dim_name in self.time_series.dims[:-1]
            }
            coords['frequency'] = self.freqs
            coords['time'] = time_axis

            if wavelet_pow_array is not None:
                wavelet_pow_array_xray = self.construct_output_array(
                    wavelet_pow_array, dims=dims, coords=coords)
            if wavelet_phase_array is not None:
                wavelet_phase_array_xray = self.construct_output_array(
                    wavelet_phase_array, dims=dims, coords=coords)

            if wavelet_pow_array_xray is not None:
                wavelet_pow_array_xray = TimeSeriesX(wavelet_pow_array_xray)
                if len(transposed_dims):
                    wavelet_pow_array_xray = wavelet_pow_array_xray.transpose(
                        *transposed_dims)

                wavelet_pow_array_xray.attrs = self.time_series.attrs.copy()

            if wavelet_phase_array_xray is not None:
                wavelet_phase_array_xray = TimeSeriesX(
                    wavelet_phase_array_xray)
                if len(transposed_dims):
                    wavelet_phase_array_xray = wavelet_phase_array_xray.transpose(
                        *transposed_dims)

                wavelet_phase_array_xray.attrs = self.time_series.attrs.copy()

            return wavelet_pow_array_xray, wavelet_phase_array_xray

    def compute_wavelet_ffts(self):

        samplerate = self.time_series.attrs['samplerate']

        freqs = np.atleast_1d(self.freqs)

        wavelets = morlet_multi(freqs=freqs,
                                widths=self.width,
                                samplerates=samplerate)
        # ADD WARNING HERE FROM PHASE_MULTI

        num_wavelets = len(wavelets)

        # length of the time axis of the time series
        s_d = self.time_series['time'].shape[0]

        # each wavelet may have a different length, so compute a per-wavelet
        # FFT size (the next power of two of that wavelet's convolution size)
        wavelet_fft_array = []
        convolution_size_pow2 = []
        convolution_size_array = np.empty(shape=(num_wavelets,), dtype=np.int64)

        # computing wavelet ffts
        for i, wavelet in enumerate(wavelets):
            s_w = wavelet.shape[0]
            convolution_size = s_w + s_d - 1
            convolution_size_pow2.append(
                2 ** int(np.ceil(np.log2(convolution_size))))
            wavelet_fft_array.append(fft(wavelet, convolution_size_pow2[-1]))
            convolution_size_array[i] = convolution_size

        return wavelet_fft_array, convolution_size_array, convolution_size_pow2

    def filter(self):

        data_iterator = self.get_data_iterator()

        time_axis = self.time_series['time']

        time_axis_size = time_axis.shape[0]

        wavelet_pow_array, wavelet_phase_array = self.allocate_output_arrays(
            time_axis_size=time_axis_size)

        # preallocating array
        wavelet_coef_single_array = np.empty(shape=(time_axis_size,),
                                             dtype=np.complex64)

        wavelet_fft_array, convolution_size_array, convolution_size_pow2 = self.compute_wavelet_ffts(
        )
        num_wavelets = len(wavelet_fft_array)
        wavelet_start = time.time()

        for idx_tuple, signal in data_iterator:

            for w in range(num_wavelets):

                signal_fft = fft(signal, convolution_size_pow2[w])
                signal_wavelet_conv = ifft(wavelet_fft_array[w] * signal_fft)

                # computing trim indices for the wavelet_coeff array
                start_offset = (convolution_size_array[w] - time_axis_size) // 2
                end_offset = start_offset + time_axis_size

                wavelet_coef_single_array[:] = signal_wavelet_conv[
                    start_offset:end_offset]

                out_idx_tuple = idx_tuple + (w, )

                pow_array_single, phase_array_single = self.compute_power_and_phase_fcn(
                    wavelet_coef_single_array)

                self.store(out_idx_tuple, wavelet_pow_array, pow_array_single)
                self.store(out_idx_tuple, wavelet_phase_array,
                           phase_array_single)

        print('total time wavelet loop: ', time.time() - wavelet_start)
        return self.build_output_arrays(wavelet_pow_array, wavelet_phase_array,
                                        time_axis)
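
A minimal usage sketch, again assuming `ts` is a TimeSeriesX with a trailing 'time' dimension and a 'samplerate' attribute:

import numpy as np

wf = MorletWaveletFilterSimple(ts,
                               freqs=np.logspace(np.log10(3.0),
                                                 np.log10(96.0), 10),
                               width=5,
                               output='power',
                               frequency_dim_pos=0)  # put 'frequency' first
pow_ts, phase_ts = wf.filter()   # phase_ts is None for output='power'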
Example #25
class ResampleFilter(PropertiedObject, BaseFilter):
    '''
    Resample Filter
    '''

    _descriptors = [
        TypeValTuple('time_series', TimeSeriesX,
                     TimeSeriesX([0.0], dims=['time'])),
        TypeValTuple('resamplerate', float, -1.0),
        TypeValTuple('time_axis_index', int, -1),
        TypeValTuple('round_to_original_timepoints', bool, False),
    ]

    def ___syntax_helper(self):
        self.time_series = None
        self.resamplerate = None
        self.time_axis_index = None
        self.round_to_original_timepoints = None

    def __init__(self, **kwds):
        '''
        :param kwds: allowed values are:
        -------------------------------------
        :param resamplerate - new sampling frequency
        :param time_series - TimeSeriesX object
        :param time_axis_index - index of the time axis
        :param round_to_original_timepoints - boolean flag indicating whether
        timepoints from the original time axis should be reused after proper
        rounding. Default is False
        -------------------------------------
        :return: None
        '''
        self.window = None
        self.init_attrs(kwds)

    def filter(self):
        '''
        Resamples the time series.
        :return: resampled time series with sampling frequency set to resamplerate
        '''
        samplerate = float(self.time_series['samplerate'])

        time_axis_length = self.time_series.coords['time'].shape[0]
        new_length = int(
            np.round(time_axis_length * self.resamplerate / samplerate))

        if self.time_axis_index < 0:
            self.time_axis_index = self.time_series.get_axis_num('time')

        time_axis = self.time_series.coords[self.time_series.dims[
            self.time_axis_index]]

        try:
            # time axis can be a recarray with one of the fields being time
            time_axis_data = time_axis.data['time']
        except (KeyError, IndexError):
            # otherwise the time axis is most likely a plain ndarray of floats
            time_axis_data = time_axis.data

        time_idx_array = np.arange(len(time_axis))

        if self.round_to_original_timepoints:
            filtered_array, new_time_idx_array = resample(
                self.time_series.data,
                new_length,
                t=time_idx_array,
                axis=self.time_axis_index,
                window=self.window)

            new_time_idx_array = np.rint(new_time_idx_array).astype(np.int64)

            new_time_axis = time_axis[new_time_idx_array]

        else:
            filtered_array, new_time_axis = resample(self.time_series.data,
                                                     new_length,
                                                     t=time_axis_data,
                                                     axis=self.time_axis_index,
                                                     window=self.window)

        coords = []
        for i, dim_name in enumerate(self.time_series.dims):
            if i != self.time_axis_index:
                coords.append(self.time_series.coords[dim_name].copy())
            else:
                coords.append((dim_name, new_time_axis))

        filtered_time_series = xray.DataArray(filtered_array, coords=coords)
        filtered_time_series['samplerate'] = self.resamplerate
        return TimeSeriesX(filtered_time_series)
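
A minimal usage sketch, assuming `ts` is a TimeSeriesX sampled at, say, 1000 Hz:

rf = ResampleFilter(time_series=ts, resamplerate=500.0)
ts_resampled = rf.filter()
print(float(ts_resampled['samplerate']))  # -> 500.0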