def bandpass(self, data):
    """Band-pass ``data`` with a cached Butterworth SOS filter, keeping state.

    ``self.low`` / ``self.high`` are the corner frequencies normalised to
    Nyquist (1.0 == Nyquist).  On the first call (``self.zi is None``) the
    filter is designed with ``iirfilter`` and converted to second-order
    sections.  Its initial state is ``sosfilt_zi(self.sos)`` (not scaled by
    the first sample).  The final filter state is stored back in ``self.zi``,
    so consecutive calls continue one stream.

    If the high corner is at (within 1e-6) or above Nyquist, a warning is
    logged and the data are handed to ``highpass`` instead.  That path does
    not touch ``self.zi``.

    Raises
    ------
    ValueError
        If the low corner frequency is above Nyquist.
    """
    # raise for some bad scenarios
    high_corner_too_high = self.high - 1.0 > -1e-6
    if high_corner_too_high:
        logger.warning(
            ("Selected high corner frequency ({}) of bandpass is at or "
             "above Nyquist ({}). Applying a high-pass instead.").format(
                 self.freqmax, self.fe))
        return highpass(data, freq=self.freqmin, df=self.sampling_rate,
                        corners=self.corners, zerophase=self.zerophase)

    if self.low > 1:
        raise ValueError("Selected low corner frequency is above Nyquist.")

    if self.zi is None:
        # first call: design the filter once and cache coefficients + state
        zeros, poles, gain = iirfilter(self.corners, [self.low, self.high],
                                       btype='bandpass', ftype='butter',
                                       output='zpk')
        self.sos = zpk2sos(zeros, poles, gain)
        self.zi = sosfilt_zi(self.sos)

    filtered, self.zi = sosfilt(self.sos, data, zi=self.zi)
    return filtered
# Example #2
0
def butter_sosfilt(X, stopband, fs, order=6, axis=-2, zi=None, passband=None, verb=True):
    ''' Band-pass and/or band-stop X along `axis` with a cascade of Butterworth SOS filters.

    Parameters
    ----------
    X : ndarray
        Data to filter.
    stopband, passband, fs, order :
        Filter specification, forwarded to ``butter_sosfilt_sos``.
    axis : int
        Axis of X to filter along (negative values count from the end).
    zi : ndarray or None
        Filter state returned by a previous call.
        - If given, filtering continues from it.
        - If None and ``axis`` is the second-to-last axis of X, a fresh state
          is built with ``sosfilt_zi`` and warmed up with ``sosfilt_zi_warmup``.
        - Otherwise the filter starts from rest.
    verb : bool
        If True, print diagnostic messages.

    Returns
    -------
    (X, sos, zi)
        Filtered data, filter coefficients, and final filter state
        (``zi`` is None when the filter was run without state).
    '''
    if axis < 0:  # normalise a negative axis so it can be compared below
        axis = X.ndim + axis
    # TODO []: auto-order determination?
    sos = butter_sosfilt_sos(stopband, fs, order, passband=passband)
    sos = sos.astype(X.dtype)  # keep as single precision

    if zi is None:
        if axis == X.ndim - 2:
            # astype returns a new array, so the result must be reassigned;
            # otherwise a float64 zi would promote the output to float64.
            zi = sosfilt_zi(sos).astype(X.dtype)  # (order,2)
            zi = sosfilt_zi_warmup(zi, X, axis, sos)
        elif verb:
            print("Warning: not warming up...")
    # else: continue from the caller-supplied filter state (previously this
    # state was silently discarded, breaking chunk-wise streaming).

    # Apply the (warmed up / continued) filter to the input data
    if zi is not None:
        X, zi = sosfilt(sos, X, axis=axis, zi=zi)
    else:
        if verb:
            print("filt:no-zi")
        X = sosfilt(sos, X, axis=axis)

    # return filtered data, filter-coefficients, filter-state
    return (X, sos, zi)
    def __init__(self, output_count, dtype, color):
        """
        Parameters
        ----------
        output_count : int
            number of channels to generate
        dtype : str or numpy.dtype or type
            data type to generate
        color : str
            coloration of noise to generate; case-insensitive, one of
            "PINK", "EM" or "EIGENMIKE"

        Raises
        ------
        NotImplementedError
            if `color` names an unsupported coloration

        Notes
        -----
        After the SOS filter and the per-channel delay conditions are set up,
        a block of ``t60_samples`` samples is generated and discarded to skip
        the IIR filter's transient response.
        """
        super().__init__(output_count=output_count, dtype=dtype)
        color = color.upper()

        # initialize constants
        self._GAIN_FACTOR = 20
        # fmt: off
        self._B_PINK = np.array([0.049922035, -0.095993537, 0.050612699, -0.004408786], dtype=dtype)
        self._A_PINK = np.array([1, -2.494956002, 2.017265875, -0.522189400], dtype=dtype)
        # "Consider designing filters in ZPK format and converting directly to SOS."
        # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.tf2sos.html
        self._SOS_EM = np.array(
            [[0.0248484318401191, -0.0430465879719351, 0.0185894592446002, 1, -1.83308400613427, 0.833084457582538],
             [1.09742855358091, -2, 0.902572344822294, 1, -1.84861597668780, 0.849007279800584],
             [1.32049919002049, -2, 1.27062183754708, 1, -1.51266987481441, 0.958710654962438],
             [2, -1.23789744854974, 0.693137522016399, 1, -0.555672516031578, 0.356572277622976],
             [2, -0.127412449936323, 0.451198075878658, 1, -0.0464446484127454, 0.0651312831643292],
             [0.295094951044653, 0.0954033015853709, 0, 1, 0, 0]],
            dtype=dtype,
        )
        # fmt: on

        # Pick the utilized coefficients. `denominator` is the transfer-function
        # denominator whose largest pole radius sets the decay time below.
        if color in ("EM", "EIGENMIKE"):
            _, denominator = sos2tf(sos=self._SOS_EM)
            self._sos = self._SOS_EM
        elif color == "PINK":
            # "It is generally discouraged to convert from TF to SOS format, since doing so
            # usually will not improve numerical precision errors."
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.tf2sos.html
            denominator = self._A_PINK
            self._sos = tf2sos(b=self._B_PINK, a=self._A_PINK).astype(dtype)
        else:
            raise NotImplementedError(
                f'chosen noise generator color "{color}" not implemented yet.'
            )

        # IIR filter delay conditions: the (n_sections, 2) sosfilt_zi state is
        # repeated along a new middle axis, giving (n_sections, output_count, 2).
        initial_delays = sosfilt_zi(sos=self._sos).astype(dtype)
        self._last_delays = np.repeat(
            initial_delays[:, np.newaxis, :], repeats=output_count, axis=1
        )

        # approximate decay time to skip transient response part of IIR filter, according to [1]
        max_pole_radius = np.abs(np.roots(denominator)).max()
        t60_samples = int(np.log(1000.0) / (1.0 - max_pole_radius)) + 1
        # generate and discard long enough sequence to skip IIR transient response (decay time)
        _ = self.generate_block(block_length=t60_samples)
# Example #4
0
def _sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None, method='pad', irlen=None):
    """Filtfilt version using Second Order sections. Code is taken from scipy.signal.filtfilt and adapted to make it work with SOS.
    Note that broadcasting does not work.
    """
    from scipy.signal import sosfilt_zi
    from scipy.signal._arraytools import odd_ext, axis_slice, axis_reverse
    x = np.asarray(x)

    if padlen is None:
        edge = 0
    else:
        edge = padlen

    # x's 'axis' dimension must be bigger than edge.
    if x.shape[axis] <= edge:
        raise ValueError("The length of the input vector x must be at least "
                         "padlen, which is %d." % edge)

    if padtype is not None and edge > 0:
        # Make an extension of length `edge` at each
        # end of the input array.
        if padtype == 'even':
            ext = even_ext(x, edge, axis=axis)
        elif padtype == 'odd':
            ext = odd_ext(x, edge, axis=axis)
        else:
            ext = const_ext(x, edge, axis=axis)
    else:
        ext = x

    # Get the steady state of the filter's step response.
    zi = sosfilt_zi(sos)

    # Reshape zi and create x0 so that zi*x0 broadcasts
    # to the correct value for the 'zi' keyword argument
    # to lfilter.
    #zi_shape = [1] * x.ndim
    #zi_shape[axis] = zi.size
    #zi = np.reshape(zi, zi_shape)
    x0 = axis_slice(ext, stop=1, axis=axis)
    # Forward filter.
    (y, zf) = sosfilt(sos, ext, axis=axis, zi=zi * x0)

    # Backward filter.
    # Create y0 so zi*y0 broadcasts appropriately.
    y0 = axis_slice(y, start=-1, axis=axis)
    (y, zf) = sosfilt(sos, axis_reverse(y, axis=axis), axis=axis, zi=zi * y0)

    # Reverse y.
    y = axis_reverse(y, axis=axis)

    if edge > 0:
        # Slice the actual signal from the extended signal.
        y = axis_slice(y, start=edge, stop=-edge, axis=axis)

    return y
# Example #5
0
def _sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None, method='pad', irlen=None):
    """Filtfilt version using Second Order sections. Code is taken from scipy.signal.filtfilt and adapted to make it work with SOS.
    Note that broadcasting does not work.
    """
    from scipy.signal import sosfilt_zi
    from scipy.signal._arraytools import odd_ext, axis_slice, axis_reverse
    x = np.asarray(x)
    
    if padlen is None:
        edge = 0
    else:
        edge = padlen

    # x's 'axis' dimension must be bigger than edge.
    if x.shape[axis] <= edge:
        raise ValueError("The length of the input vector x must be at least "
                         "padlen, which is %d." % edge)

    if padtype is not None and edge > 0:
        # Make an extension of length `edge` at each
        # end of the input array.
        if padtype == 'even':
            ext = even_ext(x, edge, axis=axis)
        elif padtype == 'odd':
            ext = odd_ext(x, edge, axis=axis)
        else:
            ext = const_ext(x, edge, axis=axis)
    else:
        ext = x

    # Get the steady state of the filter's step response.
    zi = sosfilt_zi(sos)

    # Reshape zi and create x0 so that zi*x0 broadcasts
    # to the correct value for the 'zi' keyword argument
    # to lfilter.
    #zi_shape = [1] * x.ndim
    #zi_shape[axis] = zi.size
    #zi = np.reshape(zi, zi_shape)
    x0 = axis_slice(ext, stop=1, axis=axis)
    # Forward filter.
    (y, zf) = sosfilt(sos, ext, axis=axis, zi=zi * x0)

    # Backward filter.
    # Create y0 so zi*y0 broadcasts appropriately.
    y0 = axis_slice(y, start=-1, axis=axis)
    (y, zf) = sosfilt(sos, axis_reverse(y, axis=axis), axis=axis, zi=zi * y0)

    # Reverse y.
    y = axis_reverse(y, axis=axis)

    if edge > 0:
        # Slice the actual signal from the extended signal.
        y = axis_slice(y, start=edge, stop=-edge, axis=axis)

    return y
    def butter_bandpass_sosfiltfilt(self,
                                    lowcut=None,
                                    highcut=None,
                                    accel_axis="z",
                                    order=5,
                                    axis=-1,
                                    padtype="odd",
                                    padlen=None,
                                    method='pad',
                                    irlen=None):
        '''Zero-phase Butterworth band-pass of one acceleration column.

        Reads column ``accel_axis`` of ``self.df`` and gets SOS coefficients
        from ``self.butter_bandpass(lowcut, highcut, order)``.  The data are
        then filtered forwards and backwards; the procedure is adapted from
        ``scipy.signal.filtfilt``.

        - No padding is applied unless ``padlen`` is given.
        - ``padtype`` 'even' / 'odd' selects that extension; any other
          non-None value uses a constant extension.
        - ``method`` and ``irlen`` are accepted but unused.
        - Broadcasting across extra dimensions is not supported.

        Returns the filtered samples as an ndarray.
        '''
        samples = np.asarray(getattr(self.df, accel_axis[:]))
        coeffs = self.butter_bandpass(lowcut, highcut, order)

        pad = padlen if padlen is not None else 0

        if samples.shape[axis] <= pad:
            raise ValueError(
                "The length of the input vector x must be at least padlen, which is %d."
                % pad)

        if padtype is None or not (pad > 0):
            extended = samples
        elif padtype == "even":
            extended = even_ext(samples, pad, axis=axis)
        elif padtype == "odd":
            extended = odd_ext(samples, pad, axis=axis)
        else:
            extended = const_ext(samples, pad, axis=axis)

        # Steady state of each section's step response; scaled by the first
        # (forward pass) or last (backward pass) sample to avoid a transient.
        steady_state = sosfilt_zi(coeffs)

        first_sample = axis_slice(extended, stop=1, axis=axis)
        forward, _ = sosfilt(coeffs, extended, axis=axis,
                             zi=steady_state * first_sample)

        last_sample = axis_slice(forward, start=-1, axis=axis)
        backward, _ = sosfilt(coeffs, axis_reverse(forward, axis=axis),
                              axis=axis, zi=steady_state * last_sample)
        result = axis_reverse(backward, axis=axis)

        if pad > 0:
            # drop the padding that was added at both ends
            result = axis_slice(result, start=pad, stop=-pad, axis=axis)

        return result
# Example #7
0
def save_butter_sosfilt_coeff(filename=None, stopband=((0,5),(25,-1)), fs=200, order=6):
    ''' Design a Butterworth SOS filter cascade and pickle its coefficients.

    The file receives two consecutive pickles: the ``sos`` array returned by
    ``butter_sosfilt_sos(stopband, fs, order, passband=None)``, followed by its
    ``sosfilt_zi`` initial state.  Read them back with two ``pickle.load``
    calls in the same order (only ever unpickle trusted files).

    Parameters
    ----------
    filename : str or None
        Output path.  If None, a descriptive name is built from `stopband`
        and `fs`.  NOTE(review): `order` is not part of the auto-generated
        name, so designs differing only in order overwrite each other.
    stopband, fs, order :
        Filter specification, forwarded to ``butter_sosfilt_sos``.

    Returns
    -------
    str
        The filename that was written.
    '''
    import pickle
    sos = butter_sosfilt_sos(stopband, fs, order, passband=None)
    zi = sosfilt_zi(sos)
    if filename is None:
        # auto-generate descriptive filename
        filename = "butter_stopband{}_fs{}.pk".format(stopband, fs)
    # the with-block closes the file; an explicit close() is redundant
    with open(filename, 'wb') as f:
        pickle.dump(sos, f)
        pickle.dump(zi, f)
    return filename
# Example #8
0
    def add_series(self, chan_info):
        """Add a plot row for one channel and register its bookkeeping.

        Steps:
        - Append a plot in the next free row of the window's
          ``pg.GraphicsLayoutWidget``.
        - Take the pen colour from the theme's ``pencolors``, advancing
          ``plot_config['color_iterator']`` before use.
        - Pre-fill ``plot_config['n_segments']`` zero-valued curves that
          together span ``int(x_range * samplingRate)`` sample indices
          (every ``DSFAC``-th index when ``plot_config['downsample']`` is set).
        - Add a movable horizontal threshold line.
        - Store the per-channel state in
          ``self.segmented_series[chan_info['label']]``.

        Parameters
        ----------
        chan_info : dict
            Must contain ``'label'`` (plot title and registry key) and
            ``'chan'`` (stored as ``'chan_id'``).
        """
        # Plot for this channel
        glw = self.findChild(pg.GraphicsLayoutWidget)
        new_plot = glw.addPlot(row=len(self.segmented_series),
                               col=0,
                               title=chan_info['label'],
                               enableMenu=False)
        new_plot.setMouseEnabled(x=False, y=False)

        # Appearance settings
        my_theme = THEMES[self.plot_config['theme']]
        self.plot_config['color_iterator'] = (
            self.plot_config['color_iterator'] + 1) % len(
                my_theme['pencolors'])
        pen_color = QColor(
            my_theme['pencolors'][self.plot_config['color_iterator']])

        # Prepare plot data.
        # Sample indices are built as int32. The previous int16 silently wrapped
        # around to negative values once the window exceeded 32767 samples
        # (e.g. 2 s at 30 kHz = 60000 samples).
        n_segments = self.plot_config['n_segments']
        total_samples = int(self.plot_config['x_range'] * self.samplingRate)
        samples_per_segment = int(
            np.ceil(self.plot_config['x_range'] * self.samplingRate /
                    n_segments))
        for ix in range(n_segments):
            seg_start = ix * samples_per_segment
            if ix < (n_segments - 1):
                seg_stop = (ix + 1) * samples_per_segment
            else:
                # Last segment might not be full length.
                seg_stop = total_samples
            seg_x = np.arange(seg_start, seg_stop, dtype=np.int32)
            if self.plot_config['downsample']:
                seg_x = seg_x[::DSFAC]
            c = new_plot.plot(parent=new_plot, pen=pen_color)  # PlotDataItem
            c.setData(x=seg_x, y=np.zeros_like(seg_x))  # Pre-fill.

        # Add threshold line
        thresh_line = pg.InfiniteLine(angle=0, movable=True)
        thresh_line.sigPositionChangeFinished.connect(
            self.on_thresh_line_moved)
        new_plot.addItem(thresh_line)

        self.segmented_series[chan_info['label']] = {
            'chan_id': chan_info['chan'],
            'line_ix': len(self.segmented_series),
            'plot': new_plot,
            'last_sample_ix': -1,
            'thresh_line': thresh_line,
            # initial high-pass filter state (steady state for a unit step)
            'hp_zi': signal.sosfilt_zi(self.plot_config['hp_sos']),
            'ln_zi': None
        }
def apply_sos(signal, sos, states=None, axis=0):
    r"""Filter the data along one dimension using second order sections.

    Filter the input data using a digital IIR filter defined by sos.

    Parameters
    ----------
    signal : array like
        The input signal
    sos : array like
        Array of second-order filter coefficients, must have shape
        (n_sections, 6). Each row corresponds to a second-order
        section, with the first three columns providing the numerator
        coefficients and the last three providing the denominator
        coefficients.
    states : True, None or array_like, optional
        Initial conditions for the filter.
        - If True, the conditions for a (unit) step response are
          constructed.
        - If None, initial rest is assumed (all 0).
        - Otherwise, the initial filter delay values, shaped
          ``(n_sections, ..., 2, ...)``: the shape of `signal` with the
          `axis` dimension replaced by 2, prefixed by ``n_sections``.
    axis : int, optional
        The axis of `signal` along which the filter is applied. Default 0.

    Returns
    -------
    sig_out : ndarray
        The output of the digital filter
    states : ndarray
        the final filter delay values

    """
    # Kept for the input checks it (presumably) performs. Its channel count is
    # no longer needed because the state shape is derived from `axis` below.
    audio._duration_is_signal(signal, None, None)

    if states is True or states is None:
        signal = np.asarray(signal)
        n_sections = np.shape(sos)[0]
        # sosfilt expects zi shaped like the signal with the filtered axis
        # replaced by 2, prefixed by n_sections; this honours `axis`.
        state_shape = [n_sections] + list(signal.shape)
        state_shape[1:][axis]  # validates `axis` against signal.ndim
        state_shape[axis % signal.ndim + 1] = 2
        if states is True:
            # Broadcast the (n_sections, 2) step-response state over all
            # other axes, keeping sosfilt_zi's dtype (e.g. float32).
            zi = sig.sosfilt_zi(sos)
            zi_shape = [1] * (signal.ndim + 1)
            zi_shape[0] = n_sections
            zi_shape[axis % signal.ndim + 1] = 2
            states = np.broadcast_to(zi.reshape(zi_shape), state_shape).copy()
        else:
            states = np.zeros(state_shape)

    sig_out, states = sig.sosfilt(sos, signal, zi=states, axis=axis)

    return sig_out, states
# Example #10
0
    def test_Butterworth(self):
        """The real-time Butterworth filter must reproduce scipy's sosfilt.

        A 250-sample signal (1 for the first 100 samples, then 0) is filtered
        twice:
        - in one batch with ``signal.sosfilt``, starting from
          ``sosfilt_zi`` (unscaled, which matches x[0] == 1);
        - one sample at a time through ``filters.Butterworth``.
        The two outputs must agree.
        """
        x = (np.arange(250) < 100).astype(int).tolist()
        sos = signal.butter(4, 0.1, output='sos')
        zi = signal.sosfilt_zi(sos)
        y, _ = signal.sosfilt(sos, x, zi=zi)

        test_filter = filters.Butterworth(N=4, Wn=0.1)
        # Simulate data coming in real time
        y_real_time_filter = [test_filter.filter(new_val) for new_val in x]

        self.assertEqual(len(y_real_time_filter), len(y))
        # A tight tolerance instead of exact float equality: two
        # implementations of the same recursion may differ in the last ulp.
        np.testing.assert_allclose(y_real_time_filter, y, rtol=1e-12, atol=1e-12)
# Example #11
0
    def process_data(self, data_in):
        """
        Filter a new chunk of streaming data, carrying the filter state across calls.

        On the first call the SOS state is initialised to ``sosfilt_zi`` scaled
        by the first sample, i.e. the steady state for a constant input at
        that level.
        :param data_in: single float or iterable.
        :return: np.array with filtered data (or single element if input is single element).
        """
        samples = self.to_iter(data_in)  # normalise to an indexable array
        if self.z is None:
            self.z = sosfilt_zi(sos=self.sos) * samples[0]
        filtered, self.z = sosfilt(sos=self.sos, x=samples, zi=self.z)
        return self.to_single(filtered)
# Example #12
0
    def init_bandpass_filter(self, low, high):
        """
        Initialize the bandpass filter. The filter is a butter filter of order
        neurodecode.stream_viewer._scope.BP_ORDER, stored as second-order
        sections.

        Parameters
        ----------
        low : int | float
            The frequency at which the signal is high-passed.
        high : int | float
            The frequency at which the signal is low-passed.

        Sets ``bp_low`` / ``bp_high`` (normalised to Nyquist), ``sos``,
        ``zi_coeff`` and resets ``zi`` to None.  ``zi_coeff`` is
        ``sosfilt_zi(sos)`` reshaped to (n_sections, 2, 1), presumably so it
        can be broadcast against per-channel values later (confirm at the
        call site).
        """
        nyquist = 0.5 * self.sample_rate
        self.bp_low = low / nyquist
        self.bp_high = high / nyquist
        self.sos = butter(BP_ORDER, [self.bp_low, self.bp_high],
                          btype='band', output='sos')
        n_sections = self.sos.shape[0]
        self.zi_coeff = sosfilt_zi(self.sos).reshape((n_sections, 2, 1))
        self.zi = None
# Example #13
0
 def __init__(self, N: int, Wn: "float | tuple[float, float]", btype='low', fs=None):
     '''Butterworth filter stored as second-order sections, plus its initial state.

     N: filter order
     Wn: cutoff. By default a normalized cutoff (cutoff freq / Nyquist freq);
         if fs is passed, a cutoff in Hz. For 'bandpass' pass a (low, high)
         pair.
     btype: 'low', 'high', or 'bandpass' (passed to scipy.signal.butter)
     fs: Optional sample freq, Hz. If not None, Wn describes the cutoff freq
         in Hz and is converted to a normalized cutoff here.

     Sets self.sos and self.zi = signal.sosfilt_zi(self.sos), the steady
     state for a unit step (not scaled to any input level).
     '''
     self.N = N
     if fs is not None:
         nyquist = fs / 2
         try:
             self.Wn = Wn / nyquist
         except TypeError:
             # A list/tuple of band edges (e.g. for 'bandpass') cannot be
             # divided directly; normalise each edge.
             self.Wn = [w / nyquist for w in Wn]
     else:
         self.Wn = Wn
     self.btype = btype
     self.sos = signal.butter(N=self.N,
                              Wn=self.Wn,
                              btype=self.btype,
                              output='sos')
     self.zi = signal.sosfilt_zi(self.sos)
# Example #14
0
    def update(self):
        """Filter the incoming chunk with the SOS filter, keeping state between calls.

        - The output port is set to the input port object (``self.o = self.i``),
          so the meta is passed through.
        - Returns early when the input has no data.
        - On first data:
          - the columns are remembered;
          - the nominal rate is taken from ``meta["rate"]``, falling back to
            1.0 Hz with a warning;
          - the SOS filter is designed if needed;
          - one filter state per column is initialised, scaled by that
            column's first sample.
        - ``self.o.data`` is replaced by the filtered DataFrame (same columns
          and index).
        """
        # copy the meta (this aliases the input port object rather than copying it)
        self.o = self.i

        # When we have not received data, there is nothing to do
        if not self.i.ready():
            return

        if self._columns is None:
            self._columns = self.i.data.columns

        if self._rate is None:
            nominal_rate = self.i.meta.get("rate", None)
            if nominal_rate is None:
                # If there is no rate in the meta, set rate to 1.0
                self._rate = 1.0
                self.logger.warning(
                    "Nominal rate not supplied, considering 1.0 Hz instead. "
                )
            else:
                self._rate = nominal_rate
                self.logger.info(f"Nominal rate set to {self._rate}. ")

        if self._sos is None:
            self._design_sos()

        if self._zi is None:
            # (n_sections, 2) steady state, scaled per column by its first
            # sample and stacked into (n_sections, n_columns, 2)
            unit_state = signal.sosfilt_zi(self._sos)
            per_column = []
            for col_ix in range(len(self._columns)):
                per_column.append(unit_state * self.i.data.iloc[0, col_ix])
            self._zi = np.stack(per_column, axis=1)

        # sosfilt runs along the last axis, so filter the transposed samples
        filtered, self._zi = signal.sosfilt(
            self._sos, self.i.data.values.T, zi=self._zi
        )
        self.o.data = pd.DataFrame(
            filtered.T, columns=self._columns, index=self.i.data.index
        )
# Example #15
0
def _sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
    """Do SciPy sosfiltfilt."""
    from scipy.signal import sosfilt, sosfilt_zi
    sos, n_sections = _validate_sos(sos)

    # `method` is "pad"...
    ntaps = 2 * n_sections + 1
    ntaps -= min((sos[:, 2] == 0).sum(), (sos[:, 5] == 0).sum())
    edge, ext = _validate_pad(padtype, padlen, x, axis, ntaps=ntaps)

    # These steps follow the same form as filtfilt with modifications
    zi = sosfilt_zi(sos)  # shape (n_sections, 2) --> (n_sections, ..., 2, ...)
    zi_shape = [1] * x.ndim
    zi_shape[axis] = 2
    zi.shape = [n_sections] + zi_shape
    x_0 = axis_slice(ext, stop=1, axis=axis)
    (y, zf) = sosfilt(sos, ext, axis=axis, zi=zi * x_0)
    y_0 = axis_slice(y, start=-1, axis=axis)
    (y, zf) = sosfilt(sos, axis_reverse(y, axis=axis), axis=axis, zi=zi * y_0)
    y = axis_reverse(y, axis=axis)
    if edge > 0:
        y = axis_slice(y, start=edge, stop=-edge, axis=axis)
    return y
# Example #16
0
def _sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
    """Do SciPy sosfiltfilt."""
    from scipy.signal import sosfilt, sosfilt_zi
    sos, n_sections = _validate_sos(sos)

    # `method` is "pad"...
    ntaps = 2 * n_sections + 1
    ntaps -= min((sos[:, 2] == 0).sum(), (sos[:, 5] == 0).sum())
    edge, ext = _validate_pad(padtype, padlen, x, axis,
                              ntaps=ntaps)

    # These steps follow the same form as filtfilt with modifications
    zi = sosfilt_zi(sos)  # shape (n_sections, 2) --> (n_sections, ..., 2, ...)
    zi_shape = [1] * x.ndim
    zi_shape[axis] = 2
    zi.shape = [n_sections] + zi_shape
    x_0 = axis_slice(ext, stop=1, axis=axis)
    (y, zf) = sosfilt(sos, ext, axis=axis, zi=zi * x_0)
    y_0 = axis_slice(y, start=-1, axis=axis)
    (y, zf) = sosfilt(sos, axis_reverse(y, axis=axis), axis=axis, zi=zi * y_0)
    y = axis_reverse(y, axis=axis)
    if edge > 0:
        y = axis_slice(y, start=edge, stop=-edge, axis=axis)
    return y
# Example #17
0
def sosfiltfilt(sos, x, axis=-1, padtype='odd', padlen=None):
    """
    A forward-backward filter using cascaded second-order sections.
    See `filtfilt` for more complete information about this method.
    Parameters
    ----------
    sos : array_like
        Array of second-order filter coefficients, must have shape
        ``(n_sections, 6)``. Each row corresponds to a second-order
        section, with the first three columns providing the numerator
        coefficients and the last three providing the denominator
        coefficients.
    x : array_like
        The array of data to be filtered.
    axis : int, optional
        The axis of `x` to which the filter is applied.
        Default is -1.
    padtype : str or None, optional
        Must be 'odd', 'even', 'constant', or None.  This determines the
        type of extension to use for the padded signal to which the filter
        is applied.  If `padtype` is None, no padding is used.  The default
        is 'odd'.
    padlen : int or None, optional
        The number of elements by which to extend `x` at both ends of
        `axis` before applying the filter.  This value must be less than
        ``x.shape[axis] - 1``.  ``padlen=0`` implies no padding.
        The default value is::
            3 * (2 * len(sos) + 1 - min((sos[:, 2] == 0).sum(),
                                        (sos[:, 5] == 0).sum()))
        The extra subtraction at the end attempts to compensate for poles
        and zeros at the origin (e.g. for odd-order filters) to yield
        equivalent estimates of `padlen` to those of `filtfilt` for
        second-order section filters built with `scipy.signal` functions.
    Returns
    -------
    y : ndarray
        The filtered output with the same shape as `x`.
    See Also
    --------
    filtfilt, sosfilt, sosfilt_zi
    Notes
    -----
    .. versionadded:: 0.18.0
    """
    import numpy as np
    sos, n_sections = _validate_sos(sos)
    # `x` is documented as array_like; convert so .ndim/.shape work for lists
    x = np.asarray(x)

    # `method` is "pad"...
    ntaps = 2 * n_sections + 1
    ntaps -= min((sos[:, 2] == 0).sum(), (sos[:, 5] == 0).sum())
    edge, ext = _validate_pad(padtype, padlen, x, axis, ntaps=ntaps)

    # These steps follow the same form as filtfilt with modifications
    zi = sosfilt_zi(sos)  # shape (n_sections, 2) --> (n_sections, ..., 2, ...)
    zi_shape = [1] * x.ndim
    zi_shape[axis] = 2
    # reshape() instead of assigning to .shape, which NumPy discourages
    zi = zi.reshape([n_sections] + zi_shape)
    x_0 = axis_slice(ext, stop=1, axis=axis)
    (y, zf) = sosfilt(sos, ext, axis=axis, zi=zi * x_0)
    y_0 = axis_slice(y, start=-1, axis=axis)
    (y, zf) = sosfilt(sos, axis_reverse(y, axis=axis), axis=axis, zi=zi * y_0)
    y = axis_reverse(y, axis=axis)
    if edge > 0:
        # strip the padding added by _validate_pad
        y = axis_slice(y, start=edge, stop=-edge, axis=axis)
    return y
    def __init__(self, output_count, dtype, color):
        """
        Parameters
        ----------
        output_count : int
            number of channels to generate
        dtype : str or numpy.dtype or type
            data type to generate
        color : str
            coloration of noise to generate; case-insensitive, one of
            "PINK", "EM" or "EIGENMIKE"

        Raises
        ------
        NotImplementedError
            if `color` names an unsupported coloration

        Notes
        -----
        After the SOS filter and the per-channel delay conditions are set up,
        a block of ``t60_samples`` samples is generated and discarded to skip
        the IIR filter's transient response.
        """
        super().__init__(output_count=output_count, dtype=dtype)
        color = color.upper()

        # initialize constants
        self._GAIN_FACTOR = 20

        # fmt: off
        self._B_PINK = np.array(
            [0.049922035, -0.095993537, 0.050612699, -0.004408786],
            dtype=dtype)
        self._A_PINK = np.array([1, -2.494956002, 2.017265875, -0.522189400],
                                dtype=dtype)

        # Alternative coefficients according to Fig. 5a (+30 dB gain) in [2]:
        # [[0.642831971019108, -1.285660835867235, 0.642829227003050, 1, -1.938240831077712, 0.938773412606562],
        #  [1.026898172219099, -1.958636131748593, 0.936872530871250, 1, -1.998609306155938, 0.998614711257974],
        #  [1.016158122207582, - 2, 0.984300636295923, 1, -1.869587739664868, 0.874716652825408],
        #  [2, 0.311116397866933, 0.032678524446224, 1, -1.199612825442955, 0.209446907520534],
        #  [0.025112532019265, -0.022708413470341, 0.002848575141982, 1, -0.041478045261128, 0.084631884608049]]
        self._SOS_EM = np.array(
            [[1.005691483788964, -2, 0.994309737062589, 1, -1.995131712857242, 0.995141919519996],
             [0.053750499045410, -0.070551854851138, 0.018848551742250, 1, -1.396203288062881, 0.421364316283091],
             [1.000002228996281, -2, 0.999998915874074, 1, -1.999189537733665, 0.999193631613767],
             [2, -0.050188084384710, 0.000794285722317, 1, -0.173018215643467, 0.053670187934911],
             [1.383915599293111, 1.640929515677082, 0.383739206962885, 1, 1.074403226778343, 0.393608533940758]],
            dtype=dtype)  # according to Fig. 5b (-10 dB gain) in [2]
        # fmt: on

        # Pick the utilized coefficients. `denominator` is the transfer-function
        # denominator whose largest pole radius sets the decay time below.
        if color in ("EM", "EIGENMIKE"):
            # "Consider designing filters in ZPK format and converting directly to SOS."
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.tf2sos.html
            _, denominator = sos2tf(sos=self._SOS_EM)
            self._sos = self._SOS_EM
        elif color == "PINK":
            # "It is generally discouraged to convert from TF to SOS format, since doing so
            # usually will not improve numerical precision errors."
            # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.tf2sos.html
            denominator = self._A_PINK
            self._sos = tf2sos(b=self._B_PINK, a=self._A_PINK).astype(dtype)
        else:
            raise NotImplementedError(
                f'chosen noise generator color "{color}" not implemented yet.')

        # IIR filter delay conditions: the (n_sections, 2) sosfilt_zi state is
        # repeated along a new middle axis, giving (n_sections, output_count, 2).
        initial_delays = sosfilt_zi(sos=self._sos).astype(dtype)
        self._last_delays = np.repeat(initial_delays[:, np.newaxis, :],
                                      repeats=output_count,
                                      axis=1)

        # approximate decay time to skip transient response part of IIR filter, according to [1]
        max_pole_radius = np.abs(np.roots(denominator)).max()
        t60_samples = int(np.log(1000.0) / (1.0 - max_pole_radius)) + 1
        # generate and discard long enough sequence to skip IIR transient response (decay time)
        _ = self.generate_block(block_length=t60_samples)
# Example #19
0
    def start(self):
        """Wait for the FieldTrip header, read the configuration, set up filters, then run forever.

        1. Poll ``self.ft_input.getHeader()`` (every 0.1 s) until a header
           arrives.  Raise ``RuntimeError`` once ``[fieldtrip] timeout``
           seconds (default 30) have passed.
        2. Read ``channel``, ``stride``, ``window_biofeedback`` and
           ``window_target`` from ``[input]``, and ``key_biofeedback`` from
           ``[output]``.
           - The windows and the stride are read in seconds.
           - The windows are converted to a number of stride-sized blocks.
           - Check stride < window_biofeedback < window_target.
        3. Convert the stride to samples.  The first read window ends at the
           last available sample (or starts at sample 0 if fewer than one
           stride of samples is available).
        4. Build the high-pass (15 bpm) and band-pass (4-12 bpm) Bessel SOS
           filters and their ``sosfilt_zi`` initial states.
        5. Loop forever over ``self.monitor.loop()`` and
           ``self.compute_biofeedback()``.

        Raises
        ------
        RuntimeError
            On timeout or an inconsistent window/stride configuration.
        """
        header = None
        waiting_since = time.time()
        while header is None:
            self.monitor.info("Waiting for data to arrive.")
            if (time.time() - waiting_since) > self.patch.getfloat(
                    'fieldtrip', 'timeout', default=30):
                raise RuntimeError("Timeout while waiting for data.")
            time.sleep(0.1)
            header = self.ft_input.getHeader()

        self.monitor.info("Data arrived.")

        self.channel = self.patch.getint("input", "channel")
        self.key_biofeedback = self.patch.getstring("output", "key_biofeedback")
        sfreq = header.fSample

        self.stride = self.patch.getint("input", "stride")  # seconds

        self.window_biofeedback = self.patch.getint("input", "window_biofeedback")  # seconds
        if self.stride >= self.window_biofeedback:
            raise RuntimeError("stride must be shorter than window_biofeedback.")
        # from here on: number of stride-sized blocks
        self.window_biofeedback = int(np.ceil(self.window_biofeedback / self.stride))

        self.window_target = self.patch.getint("input", "window_target")  # seconds
        if self.stride >= self.window_target:
            raise RuntimeError("stride must be shorter than window_target.")
        # from here on: number of stride-sized blocks
        self.window_target = int(np.ceil(self.window_target / self.stride))

        if self.window_biofeedback >= self.window_target:
            raise RuntimeError("window_biofeedback must be shorter than window_target.")

        self.buffer = np.zeros(self.window_target)

        # from here on the stride is in samples, for indexing
        self.stride = int(np.ceil(self.stride * sfreq))

        n_available = header.nSamples
        if n_available < self.stride:
            self.begsample = 0
        else:
            self.begsample = n_available - self.stride
        self.endsample = self.begsample + self.stride - 1

        # Initialize filters (frequencies hardcoded to prevent accidental changes
        # in the inifile, and written as bpm / 60 for readability)
        self.sos_hp = bessel_highpass(15 / 60, sfreq, 4)
        self.zi_hp = sosfilt_zi(self.sos_hp)

        self.sos_bp = bessel_bandpass(4 / 60, 12 / 60, sfreq, 2)
        self.zi_bp = sosfilt_zi(self.sos_bp)

        while True:
            self.monitor.loop()
            self.compute_biofeedback()
# Example #20
0
if __name__ == "__main__":
    # make filter
    lfp_freq = 1500.0
    filt_len = 0.01  # 10ms
    # NOTE: filt_len * lfp_freq is used as the Butterworth *order* here
    filt_order = int(filt_len * lfp_freq)
    FILT_BUF_LEN = filt_order
    filt_low = 150.0
    filt_high = 250.0
    sos = signal.butter(filt_order, (filt_low, filt_high),
                        btype='bandpass',
                        analog=False,
                        output='sos',
                        fs=lfp_freq)
    # z1 keeps track of filter state, so small buffers of a whole signal can be
    # passed in and the result will be the same as filtering the whole thing in one go
    z1 = signal.sosfilt_zi(sos)
    # Use the actual number of sections rather than assuming it equals
    # filt_order (true for a Butterworth band-pass, but not e.g. a low-pass).
    n_sections = sos.shape[0]

    # init other stuff
    # TODO change this to just the tetrodes we want to record ripple power from.
    trodes = [str(i + 1) for i in range(40)]
    # TODO change this to just tetrodes with hippocampal units on them
    spk_trodes = [i + ",0" for i in trodes]
    # One (n_sections, 2) state per tetrode -> (n_sections, n_trodes, 2)
    z1 = np.tile(np.reshape(z1, (n_sections, 1, 2)), (1, len(trodes), 1))

    bigbuf = np.zeros((len(trodes), FILT_BUF_LEN))
    bb_idx = 0

    timestamp = 0
    zcnt = 0
    REFRAC_PERIOD = 400  # ms
    refrac_end = 0
 def GetInitialFilterState(self):
     """Return ``signal.sosfilt_zi`` for ``self.secondOrderSections``.

     The result has shape (n_sections, 2): the steady-state filter delays
     for a unit step input.
     """
     return signal.sosfilt_zi(self.secondOrderSections)
# Example #22
0
def test_sosfilt_zi():
    """sosfilt_zi must keep single precision for a float32 SOS array."""
    single_section = np.array([[4, 5, 6, 1, 2, 3]], dtype=np.float32)
    steady_state = sosfilt_zi(single_section)
    assert_(steady_state.dtype == np.float32)
Exemple #23
0
 def butter_bandpass_filter_once(self, data, lowcut, highcut, fs, order=5):
     """Band-pass ``data`` in one call, starting the filter in steady state.

     The SOS filter comes from ``self.butter_bandpass(lowcut, highcut, fs,
     order=order)``. Its ``sosfilt_zi`` template is scaled by ``data[0]`` so
     the output does not jump from zero at the start.

     Returns ``(sos, filtered, zi)``. ``zi`` is the *unscaled* ``sosfilt_zi``
     template, not the final filter state after ``data`` (that state is
     discarded). Raises IndexError for empty ``data`` (``data[0]``).
     """
     sos = self.butter_bandpass(lowcut, highcut, fs, order=order)
     # SOS counterpart of lfilter_zi: steady state for a unit-step input.
     steady_state = sosfilt_zi(sos)
     filtered = sosfilt(sos, data, zi=steady_state * data[0])[0]
     return sos, filtered, steady_state
Exemple #24
0
def getRippleStatistics(tetrodes, analysis_time=4, show_ripples=False, \
        ripple_statistics=None, interrupt_ripples=False):
    """
    Get ripple data statistics for the given tetrodes over a user-defined
    time period.
    Added: 2019/02/19
    Archit Gupta

    LFP is streamed from Trodes one frame at a time into ``raw_lfp_buffer``.
    Once more than ``RiD.RIPPLE_SMOOTHING_WINDOW`` samples are buffered, the
    filter runs every ``RiD.LFP_FILTER_ORDER`` samples:

    - the latest chunk is band-pass filtered with a ripple-band Butterworth
      SOS filter, with the filter state carried across chunks;
    - an RMS ripple power is computed over the last
      ``RiD.RIPPLE_SMOOTHING_WINDOW`` filtered samples.

    If ``ripple_statistics`` is given, ripples are detected from the power of
    the FIRST tetrode only. Each detection is printed and logged, and may
    trigger a stimulation pulse.

    :tetrodes: Indices of tetrodes that should be used for collecting the
        statistics.
    :analysis_time: Amount of time (specified in seconds) for which the data
        should be analyzed to get ripple statistics.
    :show_ripples: Show ripples as they happen in real time.
    :ripple_statistics: (mean, std) of ripple power used for declaring
        something a sharp-wave ripple; ``None`` disables ripple reporting.
    :interrupt_ripples: Send a biphasic pulse through
        ``SerialPort.BiphasicPort`` for every reported ripple.
    :returns: ``(power_mean, power_std)`` as returned by
        ``Visualization.visualizeLFP``.
    :raises Exception: If the Trodes client cannot be initialized.
    """
    if show_ripples:
        plt.ion()

    if interrupt_ripples:
        ser = SerialPort.BiphasicPort()
    n_tetrodes = len(tetrodes)
    report_ripples = (ripple_statistics is not None)

    # Create a ripple filter (discrete butterworth filter with cutoff
    # frequencies set at Ripple LO and HI cutoffs.)
    ripple_filter = signal.butter(RiD.LFP_FILTER_ORDER,
                                  (RiD.RIPPLE_LO_FREQ, RiD.RIPPLE_HI_FREQ),
                                  btype='bandpass', analog=False, output='sos',
                                  fs=RiD.LFP_FREQUENCY)

    # Filter state carried from frame to frame. The (n_sections, 2) template
    # is tiled once per tetrode so all tetrodes are filtered at once along
    # axis 1.
    ripple_frame_filter = signal.sosfilt_zi(ripple_filter)
    ripple_frame_filter = np.tile(
        np.reshape(ripple_frame_filter, (RiD.LFP_FILTER_ORDER, 1, 2)),
        (1, n_tetrodes, 1))

    # Initialize a new client
    client = TrodesInterface.SGClient("RippleAnalyst")
    if (client.initialize() != 0):
        del client
        raise Exception("Could not initialize connection! Aborting.")

    # Everything after a successful connection runs inside try/finally, so
    # the Trodes connection is closed even if acquisition raises or is
    # interrupted (e.g. Ctrl-C at one of the input() prompts).
    try:
        # Access the LFP stream and create a buffer for trodes to fill LFP data into
        lfp_stream = client.subscribeLFPData(
            TrodesInterface.LFP_SUBSCRIPTION_ATTRIBUTE, tetrodes)
        lfp_stream.initialize()

        # LFP Sampling frequency TIMES desired analysis time period
        N_DATA_SAMPLES = int(analysis_time * RiD.LFP_FREQUENCY)

        # Each LFP frame (I think it is just a single time point) is returned in
        # lfp_frame_buffer. The entire timeseries is stored in raw_lfp_buffer.
        lfp_frame_buffer = lfp_stream.create_numpy_array()
        ripple_filtered_lfp = np.zeros((n_tetrodes, N_DATA_SAMPLES), dtype='float')
        raw_lfp_buffer = np.zeros((n_tetrodes, N_DATA_SAMPLES), dtype='float')
        ripple_power = np.zeros((n_tetrodes, N_DATA_SAMPLES), dtype='float')

        # Time axis (seconds) for plotting the collected data
        timestamps = np.linspace(0, analysis_time, N_DATA_SAMPLES)
        iter_idx = 0
        prev_ripple = -1.0

        # Data to be logged for later use
        ripple_events = []
        trodes_timestamps = []
        wall_ripple_times = []
        interrupt_events = []  # NOTE(review): never filled, but still logged below
        if report_ripples:
            print('Using pre-recorded ripple statistics')
            print('Mean: %.2f' % ripple_statistics[0])
            print('Std: %.2f' % ripple_statistics[1])

        if show_ripples:
            plt.figure()
            interruption_axes = plt.axes()
            plt.plot([], [])
            plt.grid(True)
            plt.ion()
            plt.show()

        input("Press Enter to start!")
        start_wall_time = time.perf_counter()
        is_first_ripple = True
        while (iter_idx < N_DATA_SAMPLES):
            n_lfp_frames = lfp_stream.available(0)
            for _ in range(n_lfp_frames):
                # Return value unused; the call presumably fills
                # lfp_frame_buffer in place -- TODO confirm with TrodesInterface.
                lfp_stream.getData()
                trodes_time_stamp = client.latestTrodesTimestamp()
                raw_lfp_buffer[:, iter_idx] = lfp_frame_buffer[:]

                # If we have enough data to fill in a new filter buffer, filter the
                # new data
                if (iter_idx > RiD.RIPPLE_SMOOTHING_WINDOW) and (
                        iter_idx % RiD.LFP_FILTER_ORDER == 0):
                    frame_start = iter_idx - RiD.LFP_FILTER_ORDER
                    lfp_frame = raw_lfp_buffer[:, frame_start:iter_idx]
                    filtered_frame, ripple_frame_filter = signal.sosfilt(
                        ripple_filter, lfp_frame, axis=1, zi=ripple_frame_filter)
                    ripple_filtered_lfp[:, frame_start:iter_idx] = filtered_frame

                    # Averaging over a longer window to be able to pick out ripples effectively.
                    # TODO: Ripple power is only being reported for each frame
                    # right now: Filling out the same value for the entire frame.
                    frame_ripple_power = np.sqrt(np.mean(np.power(
                        ripple_filtered_lfp[:, iter_idx - RiD.RIPPLE_SMOOTHING_WINDOW:iter_idx], 2), axis=1))
                    ripple_power[:, frame_start:iter_idx] = np.tile(
                        np.reshape(frame_ripple_power, (n_tetrodes, 1)),
                        (1, RiD.LFP_FILTER_ORDER))
                    if report_ripples:
                        if is_first_ripple:
                            # The first filtered chunk is never used for detection.
                            is_first_ripple = False
                        else:
                            # Show the previous interruption after a sufficient time has elapsed
                            # NOTE(review): this equality is only tested on
                            # chunk boundaries, so it can be skipped -- confirm intent.
                            if show_ripples:
                                if (iter_idx == int(
                                    (prev_ripple + RiD.INTERRUPTION_WINDOW) *
                                        RiD.LFP_FREQUENCY)):
                                    data_begin_idx = int(
                                        max(0,
                                            iter_idx - 2 * RiD.INTERRUPTION_TPTS))
                                    interruption_axes.clear()
                                    interruption_axes.plot(
                                        timestamps[data_begin_idx:iter_idx],
                                        raw_lfp_buffer[0, data_begin_idx:iter_idx])
                                    interruption_axes.scatter(prev_ripple,
                                                              0,
                                                              c="r")
                                    interruption_axes.set_ylim(-3000, 3000)
                                    plt.grid(True)
                                    plt.draw()
                                    plt.pause(0.001)

                            # Only the first tetrode's power is used for detection.
                            ripple_to_baseline_ratio = (frame_ripple_power[0] - ripple_statistics[0]) / \
                                    ripple_statistics[1]
                            if (ripple_to_baseline_ratio >
                                    RiD.RIPPLE_POWER_THRESHOLD):
                                current_time = float(iter_idx) / float(
                                    RiD.LFP_FREQUENCY)
                                if ((current_time - prev_ripple) >
                                        RiD.RIPPLE_REFRACTORY_PERIOD):
                                    prev_ripple = current_time
                                    current_wall_time = time.perf_counter(
                                    ) - start_wall_time
                                    time_lag = (current_wall_time - current_time)
                                    if interrupt_ripples:
                                        ser.sendBiphasicPulse()
                                    print(
                                        "Ripple @ %.2f, Real Time %.2f [Lag: %.2f], strength: %.1f"
                                        % (current_time, current_wall_time,
                                           time_lag, ripple_to_baseline_ratio))
                                    trodes_timestamps.append(trodes_time_stamp)
                                    ripple_events.append(current_time)
                                    wall_ripple_times.append(current_wall_time)

                iter_idx += 1
                if (iter_idx >= N_DATA_SAMPLES):
                    break
    finally:
        client.closeConnections()

    print("Collected raw LFP Data. Visualizing.")
    power_mean, power_std = Visualization.visualizeLFP(
        timestamps, raw_lfp_buffer, ripple_power, ripple_filtered_lfp,
        ripple_events, do_animation=False)
    if report_ripples:
        writeLogFile(trodes_timestamps, ripple_events, wall_ripple_times,
                     interrupt_events)

    # Program exits with a segmentation fault! Can't help this.
    input('Press ENTER to quit')
    return (power_mean, power_std)
Exemple #25
0
 def get_zi(self):
     """Return ``sig.sosfilt_zi`` of ``self.filter`` (assumed to be an SOS array -- confirm)."""
     sos = self.filter
     return sig.sosfilt_zi(sos)
Exemple #26
0
 w = np.linspace(0,1,len(h))
 plt.semilogy(w,h)
 plt.title("Frequency Response")
 ax = plt.subplot(122)
 w,grp = sig.group_delay(sig.sos2tf(sos))
 w = np.linspace(0,1,len(w))
 ax.plot(w,grp)
 plt.title("Group delay")
 plt.show()
 
 x = input("Try this filter? y/n (y) ") 
 if len(x)!=0 and x[0]=='n':
   continue #pick a new filter
 
 #Plot before/after for time and frequency domain
 zi = sig.sosfilt_zi(sos) # Set initial conditions
 # Use the initial conditions to produce a sane result (not jumping from 0)
 clean, zo = sig.sosfilt(sos, data, zi=zi*data[0])
 w = np.linspace(0, nyquist, len(f))
 plt.subplot(121)
 plt.plot(data, label="Unfiltered")
 bottom,top = plt.ylim()
 plt.title("Signal")
 # Despite the initial conditions, this still starts near 0
 plt.plot(clean, label="Filtered")
 plt.ylim(bottom,top)
 plt.legend(loc="best")
 plt.subplot(122)
 plt.semilogy(w,np.absolute(f), label="Unfiltered")
 g = fft.rfft(clean)
 plt.semilogy(w,np.absolute(g), label="Filtered")
Exemple #27
0
    def run(self):
        """
        Start thread execution.

        Polls ``self._lfp_consumer`` for ``(timestamp, lfp_frame, frame_time)``
        tuples until ``self.req_stop()`` returns True. Every
        ``RiD.LFP_FILTER_ORDER`` frames, this method:

        - ripple-filters the buffered window, carrying the filter state across
          windows;
        - reduces it to an RMS ripple power per tetrode, adding
          ``RiD.RIPPLE_SMOOTHING_FACTOR`` times the previous window's power;
        - references that power to the baseline tetrode and normalises it by
          the running mean/std.

        If the reference tetrode exceeds ``RiD.RIPPLE_POWER_THRESHOLD``
        outside the refractory period, ``self._trigger_condition`` is
        notified. Plot consumers are notified later through
        ``self._show_trigger`` and ``self._calib_trigger_condition``.

        :returns: Nothing
        """
        # Filter the contents of the signal frame by frame: steady-state
        # initial conditions, tiled to (RiD.LFP_FILTER_ORDER, n_tetrodes, 2)
        # so all tetrodes are filtered in a single sosfilt call along axis 1.
        ripple_frame_filter = signal.sosfilt_zi(self._ripple_filter)
        ripple_frame_filter = np.tile(
            np.reshape(ripple_frame_filter, (RiD.LFP_FILTER_ORDER, 1, 2)),
            (1, self._n_tetrodes, 1))

        # Buffers for the raw LFP window and the running ripple statistics.
        raw_lfp_window = np.zeros((self._n_tetrodes, RiD.LFP_FILTER_ORDER),
                                  dtype='float')
        previous_mean_ripple_power = np.zeros_like(self._mean_ripple_power)
        previous_inst_ripple_power = np.zeros((self._n_tetrodes, ))
        lfp_window_ptr = 0

        # A detected ripple that has not yet been handed to the LFP
        # visualisation / calibration plot.
        ripple_unseen_LFP = False
        ripple_unseen_calib = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        prev_ripple = -np.inf
        curr_time = 0.0
        curr_wall_time = time.perf_counter()

        # Keep track of the total time for which nothing was received
        down_time = 0.0
        while not self.req_stop():
            # Acquire buffered LFP frames and fill them in a filter buffer
            if self._lfp_consumer.poll():
                (timestamp, current_lfp_frame,
                 frame_time) = self._lfp_consumer.recv()
                raw_lfp_window[:, lfp_window_ptr] = current_lfp_frame
                self._local_lfp_buffer.append(current_lfp_frame)
                lfp_window_ptr += 1
                down_time = 0.0

                # If the filter window is full, filter the data and record it in ripple power
                if (lfp_window_ptr == RiD.LFP_FILTER_ORDER):
                    lfp_window_ptr = 0
                    # The ripple statistics are shared under this lock. Release it
                    # in `finally` so an exception cannot leave it held, which
                    # would block every other user of the lock.
                    self._ripple_data_access.acquire()
                    try:
                        filtered_window, ripple_frame_filter = signal.sosfilt(
                            self._ripple_filter, raw_lfp_window, axis=1,
                            zi=ripple_frame_filter)
                        current_ripple_power = np.sqrt(np.mean(np.power(filtered_window, 2), axis=1)) + \
                                (RiD.RIPPLE_SMOOTHING_FACTOR * previous_inst_ripple_power)
                        # Reference every tetrode's power to the baseline tetrode.
                        baseline_ripple_power = current_ripple_power[
                            self._ripple_baseline_tetrode.value]
                        current_ripple_power -= baseline_ripple_power
                        power_to_baseline_ratio = np.divide((current_ripple_power - self._mean_ripple_power), \
                                self._std_ripple_power)
                        previous_inst_ripple_power = current_ripple_power

                        # Fill in the shared data variables
                        self._local_ripple_power_buffer.append(
                            power_to_baseline_ratio)

                        # Timestamp has both trodes and system timestamps!
                        curr_time = float(timestamp) / RiD.SPIKE_SAMPLING_FREQ
                        logging.debug(MODULE_IDENTIFIER + "Frame @ %d filtered, mean ripple strength %.2f"%\
                                (timestamp, np.mean(power_to_baseline_ratio)))

                        if ((curr_time - prev_ripple) >
                                RiD.RIPPLE_REFRACTORY_PERIOD):
                            # Alternative (disabled): trigger on any tetrode via
                            # (power_to_baseline_ratio > RiD.RIPPLE_POWER_THRESHOLD).any()
                            if power_to_baseline_ratio[
                                    self._ripple_reference_tetrode.
                                    value] > RiD.RIPPLE_POWER_THRESHOLD:
                                prev_ripple = curr_time
                                with self._trigger_condition:
                                    # First trigger interruption and all time critical operations
                                    self._trigger_condition.notify()
                                curr_wall_time = time.perf_counter()
                                ripple_unseen_LFP = True
                                ripple_unseen_calib = True
                                logging.info(
                                    MODULE_IDENTIFIER +
                                    "Detected ripple, notified with lag of %.6fs" %
                                    (curr_wall_time - frame_time))
                                logging.info(MODULE_IDENTIFIER + "Detected ripple at %.6f, TS: %d. Peak Strength: %.2f"% \
                                        (frame_time, timestamp, power_to_baseline_ratio[self._ripple_reference_tetrode.value]))
                                if PRINT_DETECTION_MESSAGES:
                                    print(MODULE_IDENTIFIER + "Detected ripple at %.6f, TS: %d. Peak Strength: %.2f"% \
                                            (frame_time, timestamp, power_to_baseline_ratio[self._ripple_reference_tetrode.value]))

                        # For each tetrode, update the MEAN and STD for ripple power
                        # (Welford's online update: _var_ripple_power accumulates
                        # the sum of squared deviations).
                        if (self._update_ripple_stats.value) and (
                                self._n_data_pts_seen.value <
                                RiD.STAT_ADJUSTMENT_DATA_PTS):
                            self._n_data_pts_seen.value += 1
                            np.copyto(previous_mean_ripple_power,
                                      self._mean_ripple_power)
                            self._mean_ripple_power += (current_ripple_power - previous_mean_ripple_power)/\
                                    self._n_data_pts_seen.value
                            self._var_ripple_power += (current_ripple_power - previous_mean_ripple_power) * \
                                    (current_ripple_power - self._mean_ripple_power)
                            np.sqrt(self._var_ripple_power /
                                    self._n_data_pts_seen.value,
                                    out=self._std_ripple_power)
                            # Log stats every int(RiD.LFP_FREQUENCY) filtered windows
                            if self._n_data_pts_seen.value % int(
                                    1 * RiD.LFP_FREQUENCY) == 0:
                                logging.info(MODULE_IDENTIFIER + "T%s: Mean LFP %.4f, STD LFP: %.4f"%(self._target_tetrodes[self._ripple_reference_tetrode.value],\
                                        self._mean_ripple_power[self._ripple_reference_tetrode.value], self._std_ripple_power[self._ripple_reference_tetrode.value]))
                    finally:
                        self._ripple_data_access.release()

                    if ((curr_time - prev_ripple) >
                            RiD.LFP_BUFFER_TIME / 2) and ripple_unseen_LFP:
                        ripple_unseen_LFP = False
                        # Copy data over for visualization
                        if len(self._local_lfp_buffer
                               ) == RiD.LFP_BUFFER_LENGTH:
                            with self._show_trigger:
                                np.copyto(self._raw_lfp_buffer,
                                          np.asarray(self._local_lfp_buffer).T)
                                np.copyto(
                                    self._ripple_power_buffer,
                                    np.asarray(
                                        self._local_ripple_power_buffer).T)
                                self._show_trigger.notify()
                            if __debug__:
                                print(
                                    MODULE_IDENTIFIER +
                                    "%.2fs: Peak ripple power in frame %.2f" %
                                    (curr_time,
                                     np.max(self._ripple_power_buffer)))

                    if ((curr_time - prev_ripple) > RiD.CALIB_PLOT_BUFFER_TIME
                            / 2) and ripple_unseen_calib:
                        ripple_unseen_calib = False
                        if self._calib_plot is not None:
                            with self._calib_trigger_condition:
                                self._calib_plot.update_shared_buffer(
                                    timestamp)
                                self._calib_trigger_condition.notify()
            else:
                # Nothing to process yet: back off briefly and warn after 1 s of silence.
                time.sleep(0.005)
                down_time += 0.005
                if down_time > 1.0:
                    print(MODULE_IDENTIFIER +
                          "Warning: Not receiving LFP Packets.")
                    down_time = 0.0
def filter_signal(params, linear_data, process_type=None, debug_dict=None):
    """
    Filter linear_data and return new array with separate orders.

    I.e. if input data look like:
    linear_data = [num_alines, num_samples]

    output array looks like:
    aline_orders = [num_alines, num_orders, num_samples]

    One IIR band-pass filter is designed per peak listed in
    ``params['auto_params']['peaks_<process_type>_px']`` (at most
    ``use_num_orders`` of them). Every A-line is then filtered with each
    filter, using either ``sosfilt`` with the filter's ``sosfilt_zi`` state
    or zero-phase ``sosfiltfilt``.

    Please note that we use here internal plotting functions that can be enabled if required for analysis.
    To enable plotting find the function in the code and set do_plot = True
    mp.plot_filter_response_each(h,do_plot=False,do_wait=True)
    mp.plot_filter_response_vs_signal_in(filter_response, linear_fw_data, do_plot=False, do_wait=True)
    mp.plot_signal_filtered_each(aline_fw, fsigs, filter_response, do_plot=False  , do_wait=True)

    :param params: nested parameter dict. ``params['auto_params']`` and
        ``params['manual_params']['filter']`` are read. The whole dict is
        written as JSON to ``params['auto_params']['json_fname']``.
    :param linear_data: 2-D array [num_alines, num_samples].
    :param process_type: string 'fw' or 'rv' ... forward and reverse respectively.
    :param debug_dict: optional dict of debug/plot switches
        ('do_plot_filter_response', 'do_plot_filtered_sig', 'start_at', ...).
        It may be modified in place when plotting is aborted. ``None``
        (default) means no debugging.
    :return: ``(params, aline_orders)``. ``aline_orders`` has shape
        [num_alines - start_at, num_orders, num_samples].
    :raises AssertionError: if process_type is not a string.
    """
    if debug_dict is None:
        # Fresh dict per call: the function may write into debug_dict below,
        # so a shared mutable default would leak state between calls.
        debug_dict = {}
    assert isinstance(
        process_type, str), 'Please set process_typ to \'fw\' or \'rv\'!'
    print('filtering ({}):'.format(process_type))

    filter_params = params['manual_params']['filter']
    peaks_str = params['auto_params']['peaks_' + process_type +
                                      '_px'].strip('[]')
    correction_factor = filter_params['CF_correction_factor']
    # Reading the string explicitly (rather than eval) is safer.
    # "Advanced Iterators - Dive Into Python 3." https://diveintopython3.net/advanced-iterators.html#eval (accessed Jul. 22, 2020).
    peaks = np.fromstring(peaks_str, sep=',') * correction_factor
    # Placeholder for resampling support; currently fixed at 1
    # (cf. params['resample_to'] / params['segment_len_before_resample_' + process_type]).
    resample_factor = 1
    # np.int was removed in NumPy 1.24; the builtin int is what it aliased.
    peaks = np.round(peaks * resample_factor).astype(int)

    # iirdesign expects normalised frequencies in [0, 1]. Here pixel
    # positions and half-widths are normalised by the segment length.
    # NOTE(review): an earlier comment mentioned multiplying by two
    # (Nyquist normalisation), but no such factor was ever applied -- confirm.
    # The resample factor is applied exactly once (it used to be applied
    # twice, which only went unnoticed because it is 1).
    segment_len = linear_data.shape[1]
    HPW = filter_params['half_passbandwidth_px'] * resample_factor / segment_len
    HSW = filter_params['half_stopbandwidth_px'] * resample_factor / segment_len

    use_num_orders = filter_params['use_num_orders']
    if len(peaks) < use_num_orders:
        warn('Parameter use_num_orders = {} is larger than peak len {}'.format(
            use_num_orders, len(peaks)))
        # Clamp to the available peaks. Previously this used len(peaks) - 1,
        # which silently dropped one peak only in this branch.
        use_num_orders = len(peaks)

    # Pass-band (WP) and stop-band (WS) edges for each peak.
    bandwidths = []
    for peak_px in peaks[:use_num_orders]:
        peak = peak_px / segment_len
        # TODO Peak correction (a multiplicative factor, currently 1.00) did
        # not seem to improve results, so it is not applied.
        freq_dep_factor = 1  # frequency dependent factor
        bandwidths.append({
            'WP':
            np.array([peak - HPW, peak + HPW]) / freq_dep_factor,
            'WS':
            np.array([peak - HPW - HSW, peak + HPW + HSW]) / freq_dep_factor
        })

    # Design filters for each peak
    filters = []
    for BW, CP in zip(bandwidths, peaks):
        sosdata = sg.iirdesign(
            wp=BW['WP'],
            ws=BW['WS'],
            gpass=filter_params['gpass'],
            gstop=filter_params['gstop'],
            ftype=filter_params['type'],
            output='sos',
            analog=False)
        # Steady-state initial conditions (for a unit-step input)
        zi = sg.sosfilt_zi(sosdata)
        filters.append({'CP_fw': CP, 'sosdata': sosdata, 'zi': zi})

    if debug_dict.get('do_plot_filter_response'):
        mp.plot_filter_response_vs_signal_in(peaks, filters, linear_data,
                                             debug_dict)

    aline_orders = []
    print('{: 4d}'.format(0), end='')

    if debug_dict.get('do_plot_filtered_sig'):
        # A missing or None 'start_at' used to raise TypeError in the comparison.
        start_at = debug_dict.get('start_at') or 0
        if linear_data.shape[0] < start_at:
            start_at = linear_data.shape[0] // 4
    else:
        start_at = 0

    # Apply filter to all Segments
    for n, segment in enumerate(linear_data[start_at:]):
        fsigs = []  # filtered signals according to number of peaks_fw
        for flt in filters:
            if filter_params['use_initial_sos_cond']:
                fsig, _ = sg.sosfilt(flt['sosdata'], segment, zi=flt['zi'])
            else:
                fsig = sg.sosfiltfilt(flt['sosdata'], segment)
            fsigs.append(fsig)

        # Collect filtered signals
        aline_orders.append(fsigs)

        if np.mod(n, 10) == 0:
            print('\b\b\b\b{: 4d}'.format(n), end='')

        if debug_dict.get('do_plot_filtered_sig'):
            axes = mp.plot_signal_filtered_each(peaks, segment, fsigs, filters,
                                                debug_dict)
            if axes is None:
                debug_dict['do_plot_filtered_sig'] = False
                debug_dict['do_plot_filtered_sig_wait'] = False

    print('\nfinish filtering')

    # Close the parameter file deterministically (it was left open before).
    with open(params['auto_params']['json_fname'], 'w') as json_file:
        json.dump(params, json_file, indent=2)
    return params, np.array(aline_orders)
Exemple #29
0
    def run(self):
        """
        Start thread execution.

        Polls ``self._lfp_consumer`` for ``(timestamp, lfp_frame, frame_time)``
        tuples until ``self.req_stop()`` returns True. Every
        ``RiD.LFP_FILTER_ORDER`` frames, the buffered window is ripple-filtered
        (filter state carried across windows). It is then reduced to an RMS
        ripple power per tetrode and normalised by ``self._mean_ripple_power``
        / ``self._std_ripple_power``.

        If ANY tetrode exceeds ``RiD.RIPPLE_POWER_THRESHOLD`` outside the
        refractory period, ``self._trigger_condition`` is notified. Once
        ``RiD.LFP_BUFFER_TIME / 2`` has passed after a ripple, the LFP/power
        buffers are copied and ``self._show_trigger`` is notified.

        :returns: Nothing
        """
        # Filter the contents of the signal frame by frame: steady-state
        # initial conditions, tiled to (RiD.LFP_FILTER_ORDER, n_tetrodes, 2)
        # so all tetrodes are filtered in a single sosfilt call along axis 1.
        ripple_frame_filter = signal.sosfilt_zi(self._ripple_filter)
        ripple_frame_filter = np.tile(
            np.reshape(ripple_frame_filter, (RiD.LFP_FILTER_ORDER, 1, 2)),
            (1, self._n_tetrodes, 1))
        # Buffers for storing/manipulating raw LFP, ripple filtered LFP and
        # ripple power.
        raw_lfp_window = np.zeros((self._n_tetrodes, RiD.LFP_FILTER_ORDER),
                                  dtype='float')
        # The next four locals are only needed by the disabled mean/std
        # update (see TODO below); ripple_power is filled but not yet read.
        ripple_power = collections.deque(maxlen=RiD.RIPPLE_SMOOTHING_WINDOW)
        previous_mean_ripple_power = np.zeros_like(self._mean_ripple_power)
        pow_window_ptr = 0
        n_data_pts_seen = 0
        lfp_window_ptr = 0

        # Delay measures for ripple detection (and trigger)
        ripple_unseen = False
        # np.Inf was removed in NumPy 2.0; np.inf works on every version.
        prev_ripple = -np.inf
        curr_time = 0.0
        curr_wall_time = time.time()

        # Keep track of the total time for which nothing was received
        down_time = 0.0
        while not self.req_stop():
            # Acquire buffered LFP frames and fill them in a filter buffer
            if self._lfp_consumer.poll():
                (timestamp, current_lfp_frame,
                 frame_time) = self._lfp_consumer.recv()
                raw_lfp_window[:, lfp_window_ptr] = current_lfp_frame
                self._local_lfp_buffer.append(current_lfp_frame)
                lfp_window_ptr += 1
                down_time = 0.0

                # If the filter window is full, filter the data and record it in ripple power
                if (lfp_window_ptr == RiD.LFP_FILTER_ORDER):
                    lfp_window_ptr = 0
                    filtered_window, ripple_frame_filter = signal.sosfilt(
                        self._ripple_filter, raw_lfp_window, axis=1,
                        zi=ripple_frame_filter)
                    current_ripple_power = np.sqrt(
                        np.mean(np.power(filtered_window, 2), axis=1))
                    ripple_power.append(current_ripple_power)

                    # Fill in the shared data variables
                    self._local_ripple_power_buffer.append(
                        current_ripple_power)

                    # TODO: Enable this part of the code to update the mean and STD over time
                    # Update the mean and std for ripple power at each of the tetrodes:
                    #   np.copyto(previous_mean_ripple_power, self._mean_ripple_power)
                    #   self._mean_ripple_power += (ripple_power[:, pow_window_ptr] - previous_mean_ripple_power)/n_data_pts_seen
                    #   self._std_ripple_power += (ripple_power[:, pow_window_ptr] - previous_mean_ripple_power) * \
                    #           (ripple_power[:, pow_window_ptr] - self._mean_ripple_power)
                    #   # This is the accumulate sum of squares. The actual variance is <current-value>/(n_data_pts_seen-1)
                    n_data_pts_seen += 1

                    # TODO: Right now, we are not using average power in the smoothing window, but the current power.
                    power_to_baseline_ratio = np.divide(
                        current_ripple_power - self._mean_ripple_power,
                        self._std_ripple_power)

                    # Timestamp has both trodes and system timestamps!
                    curr_time = float(timestamp) / RiD.SPIKE_SAMPLING_FREQ
                    logging.debug(
                        MODULE_IDENTIFIER +
                        "Frame @ %d filtered, mean ripple strength %.2f" %
                        (timestamp, np.mean(power_to_baseline_ratio)))
                    if ((curr_time - prev_ripple) >
                            RiD.RIPPLE_REFRACTORY_PERIOD):
                        # TODO: Consider switching to all, or atleast a majority of tetrodes for ripple detection.
                        if (power_to_baseline_ratio >
                                RiD.RIPPLE_POWER_THRESHOLD).any():
                            prev_ripple = curr_time
                            with self._trigger_condition:
                                # First trigger interruption and all time critical operations
                                self._trigger_condition.notify()
                                curr_wall_time = time.time()
                                # time.time() differences are in seconds; scale to
                                # match the "ms" in the message (it used to print seconds).
                                # NOTE(review): assumes frame_time is a time.time()
                                # value as well -- confirm with the LFP producer.
                                logging.info(
                                    MODULE_IDENTIFIER +
                                    "Detected ripple, notified with lag of %.2f ms"
                                    % (1000.0 * (curr_wall_time - frame_time)))
                                ripple_unseen = True
                            logging.info(MODULE_IDENTIFIER + "Detected ripple at %.2f, TS: %d. Peak Strength: %.2f"% \
                                    (curr_time, timestamp, np.max(power_to_baseline_ratio)))
                    if ((curr_time - prev_ripple) >
                            RiD.LFP_BUFFER_TIME / 2) and ripple_unseen:
                        ripple_unseen = False
                        # Copy data over for visualization
                        if len(self._local_lfp_buffer
                               ) == RiD.LFP_BUFFER_LENGTH:
                            np.copyto(self._raw_lfp_buffer,
                                      np.asarray(self._local_lfp_buffer).T)
                            np.copyto(
                                self._ripple_power_buffer,
                                np.asarray(self._local_ripple_power_buffer).T)
                            logging.info(
                                MODULE_IDENTIFIER +
                                "%.2fs: Peak ripple power in frame %.2f" %
                                (curr_time, np.max(self._ripple_power_buffer)))
                            with self._show_trigger:
                                self._show_trigger.notify()
            else:
                # Nothing to process yet: back off briefly and warn after 1 s of silence.
                time.sleep(0.005)
                down_time += 0.005
                if down_time > 1.0:
                    print(MODULE_IDENTIFIER +
                          "Warning: Not receiving LFP data.")
                    down_time = 0.0
Exemple #30
0

# Noisy test signal around 0.45 (seeded for reproducibility).
np.random.seed(123)
n = 101
t = np.linspace(0, 1, n)
x = 0.45 + 0.1*np.random.randn(n)

# 8th-order low-pass Butterworth, cutoff 0.125 x Nyquist, in SOS form.
sos = butter(8, 0.125, output='sos')

# y: filtered from the default all-zero state (start-up transient from 0).
y = sosfilt(sos, x)

# y2: filtered from the steady state whose output is the constant
# x[:4].mean(), i.e. the unit-step sosfilt_zi template scaled by that value.
zi = sosfilt_zi(sos) * x[:4].mean()
y2, zo = sosfilt(sos, x, zi=zi)

# Plot the raw signal and both filtered versions.
plt.figure(figsize=(4.0, 2.8))
for _series, _style in ((x, {'alpha': 0.75, 'linewidth': 1, 'label': 'x'}),
                        (y, {'label': 'y  (zero ICs)'}),
                        (y2, {'label': 'y2 (mean(x[:4]) ICs)'})):
    plt.plot(t, _series, **_style)
del _series, _style

plt.legend(framealpha=1, shadow=True)
plt.grid(alpha=0.25)
plt.xlabel('t')
plt.title('Filter with different initial conditions', fontsize=10)
plt.tight_layout()
Exemple #31
0
 def butter_lowpass_filter(self, data, highcut, order=5, init=0):
     """Low-pass ``data`` with the SOS filter from ``self.butter_lowpass``.

     The filter starts from ``init`` times the unit-step steady state
     (``sosfilt_zi``). With the default ``init=0`` this is the all-zero
     state; with ``init`` equal to the signal level there is no start-up
     transient. Only the filtered signal is returned; the final filter state
     is discarded.
     """
     sos = self.butter_lowpass(highcut, order=order)
     filtered, _final_state = sosfilt(sos, data, zi=init * sosfilt_zi(sos))
     return filtered
Exemple #32
0
 def __init__(self):
     self.zic = np.zeros((electrodeNum, (max(b.size, a.size) - 1)))
     # internal state for 60Hz comb filter
     self.zis = np.expand_dims(np.tile(sig.sosfilt_zi(sos), (8, 1)), axis=0)