Esempio n. 1
0
 def work(self):
     """
     Filter the selected channels with the trained weights and write the
     rectified result.

     Takes the columns `self.channels` of `self.multichannel_full`, passes them
     together with the weights read from `self.trainer.output()` and
     `self.delays` to `convolve_spatiotemporal`, and writes the absolute value
     of that output as a `Signal` with the input's sampling rate.
     """
     selected = self.multichannel_full[:, self.channels]
     weights = self.trainer.output().read()
     filtered = convolve_spatiotemporal(selected, weights, self.delays)
     rectified = np.abs(filtered)
     result = Signal(rectified, selected.fs)
     self.output().write(result)
Esempio n. 2
0
def hilbert_fast(sig: Signal):
    """
    Analytic signal of `sig`, computed with a padded transform length.

    The length passed to `hilbert` is the smaller of the next power of 2 and
    the next power of 3 that is not below `sig.num_samples`; this padding is
    done to speed up the computation. The padding is cut off again before
    returning.

    :param sig:  Input signal.
    :return:  `Signal` holding the first `sig.num_samples` samples of the
                `hilbert` output, with the sampling rate of `sig`.
    """
    bases = array([2, 3])
    # Smallest exponent per base such that base**exponent >= num_samples.
    next_exponents = ceil(log(sig.num_samples) / log(bases))
    padded_length = int(min(bases**next_exponents))
    padded_analytic = hilbert(sig, padded_length)
    return Signal(padded_analytic[:sig.num_samples], sig.fs)
Esempio n. 3
0
 def work(self):
     """
     Compute a smoothed per-channel envelope of the required signal and write
     it to this task's output.

     Steps: band-pass filter (`self.freq_band`) along axis 0, magnitude of the
     `analytical` output (computed at a padded length), truncation back to the
     original number of samples, and smoothing with `smooth1d`. Intermediate
     arrays are explicitly deleted as soon as they are no longer needed.
     """
     raw = self.requires().output().read()
     self.update_status("Read raw signal. Applying bandpass filter.")
     # fklab's "compute_envelope" is not usable here directly: it averages the
     # envelopes of all channels into a single one.
     filter_settings = dict(
         axis=0,
         band=self.freq_band,
         fs=raw.fs,
         transition_width="20%",
         attenuation=30,
     )
     bandpassed = apply_filter(raw, **filter_settings)
     self.update_status(
         "Applied bandpass filter. Calculating raw envelope via Hilbert transform."
     )
     # Pad to the nearest power of 2 or 3 for the Hilbert transform; this
     # gives a great speedup (via FFT).
     num_samples = raw.shape[0]
     bases = array([2, 3])
     next_exponents = ceil(log(num_samples) / log(bases))
     padded_length = int(min(bases**next_exponents))
     padded_envelope = abs(analytical(bandpassed, N=padded_length, axis=0))
     del bandpassed
     # Drop the samples that only exist because of the padding.
     unsmoothed = padded_envelope[:num_samples, :]
     self.update_status("Calculated raw envelope. Smoothing envelope.")
     del padded_envelope
     smoothed = smooth1d(
         unsmoothed,
         delta=1 / raw.fs,
         kernel="gaussian",
         bandwidth=4e-3,
     )
     self.update_status("Smoothed envelope. Writing envelope to disk.")
     del unsmoothed
     result = Signal(smoothed, raw.fs, raw.units)
     del smoothed
     self.output().write(result)
     del raw, result
Esempio n. 4
0
    def work(self):
        """
        Decimate the raw recording by an integer factor and write the result.

        The factor is the integer quotient of the recording's sampling rate
        and `config.fs_target`. When the division leaves a remainder, a
        warning is logged and the resulting sampling rate (`fs_orig / q`)
        differs from the target. The decimated samples are multiplied by
        `raw_recording.to_microvolts` and written as a `Signal` with units
        "μV".

        The recording is closed even when decimation or scaling raises, so the
        underlying resource is not leaked on failure.
        """
        raw_recording = self.requires().output()
        try:
            fs_orig = raw_recording.fs
            # NOTE(review): when fs_orig < config.fs_target, q is 0 and the
            # division below fails.
            q, remainder = divmod(fs_orig, config.fs_target)
            fs_new = fs_orig / q
            if remainder > 0:
                getLogger(__name__).warning(
                    f"Original sampling rate of {self.file_ID} ({fs_orig} Hz) is"
                    f" not an integer multiple of the target sampling rate"
                    f" ({config.fs_target} Hz). Sampling rate after downsampling"
                    f" will instead be {fs_new} Hz."
                )
            self.update_status(
                f"Decimating {self.file_ID} ({self.file_ID.path}) of size"
                f" {self.file_ID.path.stat().st_size / 1E9:.1f} GB by a factor {q}."
            )
            t_prev = time()

            def track_downsampling_progress(progress: float):
                # Throttle: only report when more than 5 time units (as
                # returned by `time()`) have passed since the last report.
                nonlocal t_prev
                t_now = time()
                time_passed = t_now - t_prev
                if time_passed > 5:
                    self.update_progress(progress, "Downsampling progress")
                    t_prev = t_now

            signal_down = decimate_chunkwise(
                raw_recording.signal,
                factor=q,
                loop_callback=track_downsampling_progress,
            )
            signal_down *= raw_recording.to_microvolts
        finally:
            # Release the recording on success and on failure alike.
            raw_recording.close()
        self.output().write(Signal(signal_down, fs_new, "μV"))
Esempio n. 5
0
 def work(self):
     """
     Filter the input signal with the ripple filter and write the rectified
     output.

     Configures `self.ripple_filter` with `self.order` and the sampling rate of
     `self.input_signal`, applies its `tf` coefficients with `lfilter`, and
     writes the absolute value of the result as a `Signal` at the same
     sampling rate.
     """
     sampling_rate = self.input_signal.fs
     # The filter must know order and sampling rate before `tf` is read.
     self.ripple_filter.order = self.order
     self.ripple_filter.fs = sampling_rate
     coeffs_b, coeffs_a = self.ripple_filter.tf
     rectified = abs(lfilter(coeffs_b, coeffs_a, self.input_signal))
     self.output().write(Signal(rectified, sampling_rate))
Esempio n. 6
0
 def read(self) -> Signal:
     """
     Load the complete stored signal into memory.

     :return:  `Signal` built from the full dataset under `self.KEY_SIG`, its
                 `self.KEY_FS` attribute, and its `self.KEY_UNITS` attribute
                 (None when that attribute is absent).
     """
     log.info(f"Reading signal file at {self} ({self.size}) into memory.")
     with self.open_file_for_read() as file:
         stored = file[self.KEY_SIG]
         # Index with `[:]` while the file is still open, so that `samples`
         # does not depend on the open file afterwards.
         samples = stored[:]
         sampling_rate = stored.attrs[self.KEY_FS]
         units = stored.attrs.get(self.KEY_UNITS, None)
     return Signal(samples, sampling_rate, units)
Esempio n. 7
0
 def o_t(self):
     """
     Band-pass filtered version of `self.x_t`.

     Band and extra filter options are taken from `self.reference_maker`.

     :return:  `Signal` with the filter output, at sampling rate `self.fs`.
     """
     logger.info("Band-pass-filtering signal..")
     maker = self.reference_maker
     filtered = apply_filter(
         self.x_t, maker.band, fs=self.fs, **maker.filter_options
     )
     logger.info("Done")
     return Signal(filtered, self.fs)
Esempio n. 8
0
 def e_t(self):
     """
     Smoothed version of `self.envelope_unsmoothed`.

     Smoothing is done by `smooth1d`, with a sample spacing of `1 / self.fs`
     and the `smooth_options` of `self.reference_maker`.

     :return:  `Signal` with the smoothed envelope, at sampling rate `self.fs`.
     """
     logger.info("Smoothing envelope..")
     rm = self.reference_maker
     # NOTE(review): `self.envelope_unsmoothed` is presumably the magnitude of
     # `self.analytic` -- confirm against the rest of the class.
     envelope = smooth1d(self.envelope_unsmoothed,
                         delta=1 / self.fs,
                         **rm.smooth_options)
     logger.info("Done")
     return Signal(envelope, self.fs)
Esempio n. 9
0
 def ripple_envelope(self) -> Signal:
     """
     Offline and ripple-only algorithm.

     :return:  Output of `compute_envelope` for `self.ripple_channel` in
                 `self.ripple_band`, wrapped in a `Signal` with the channel's
                 sampling rate.
     """
     channel = self.ripple_channel
     envelope = compute_envelope(
         channel,
         self.ripple_band,
         fs=channel.fs,
         filter_options=self.ripple_filter_options,
         smooth_options=self.ripple_smooth_options,
     )
     return Signal(envelope, channel.fs)
Esempio n. 10
0
def _concat_extracts(data: Signal, segs: Segment, max_duration=60) -> Signal:
    """
    Cut `segs` out of `data` and join the pieces into one signal.

    Pieces are taken in the order `data.extract(segs)` yields them. Collection
    stops at the first piece that would push the summed `duration` beyond
    `max_duration`; that piece and all later ones are left out.

    :param data:  Signal to extract from.
    :param segs:  Segments to extract.
    :param max_duration:  Resulting array will be no longer than this.
                (Goal: limit memory usage).
    :return:  Concatenation of the kept pieces, with the sampling rate of
                `data`.
    """
    # Might wanna randomize segs order here.
    kept = []
    total_duration = 0
    for piece in data.extract(segs):
        total_duration += piece.duration
        if total_duration > max_duration:
            break
        kept.append(piece)
    return Signal(concatenate(kept), data.fs)
Esempio n. 11
0
    def as_model_io(self, sig: Signal) -> TorchArray:
        """
        Turn a signal into a batched, single-sample pytorch array placed on
        `device`.

        input shape: (num_samples, num_channels)
        output shape: (1, num_samples, num_channels)
        """
        matrix = sig.as_matrix()
        single_sample_batch = to_batch(numpy_to_torch(matrix), one_sample=True)
        # `device` is resolved from the enclosing scope (possibly a GPU).
        return single_sample_batch.to(device)
Esempio n. 12
0
 def work(self):
     """
     Compute a filter vector from a generalized eigendecomposition and write
     it to this task's output.

     The selected training channels are delay-stacked. Covariance matrices are
     then computed for the data inside the reference segments and for the data
     inside the inverted segments. The written vector is the generalized
     eigenvector of that pair (as returned by `eigh`) that belongs to the
     largest eigenvalue.
     """
     selected = self.multichannel_train[:, self.channels]
     reference_segs = self.reference_segs_train
     stacked = Signal(data=delay_stack(selected, self.delays), fs=selected.fs)
     inside = _concat_extracts(stacked, reference_segs)
     outside = _concat_extracts(stacked, reference_segs.invert())
     log.info(f"Reference data length: {inside.duration} s")
     log.info(f"Background data length: {outside.duration} s")
     # Rows are (time) samples; columns = channels = variables.
     cov_reference = _as_matrix(cov(inside, rowvar=False))
     cov_background = _as_matrix(cov(outside, rowvar=False))
     eigenvalues, eigenvectors = eigh(cov_reference, cov_background)
     strongest = argmax(eigenvalues)
     self.output().write(eigenvectors[:, strongest])
Esempio n. 13
0
def calc_detections(envelope: Signal, threshold: float,
                    lockout_time: float) -> ndarray:
    """
    Detection times for `envelope`, as found by `calc_detection_indices`.

    :param envelope:  Signal to run the detection on. Passed on as floats.
    :param threshold:  Passed on to `calc_detection_indices` as a float.
    :param lockout_time: In seconds. Converted to a number of samples at the
                sampling rate of `envelope`.
    :return: Array of detection times, in seconds.
    """
    lockout = time_to_index(lockout_time, envelope.fs)
    # Explicit casts: the detection routine gets floats for the signal and
    # the threshold, and an int for the lockout.
    detection_args = (
        envelope.astype(float),
        float(threshold),
        lockout.astype(int),
    )
    sample_indices = calc_detection_indices(*detection_args)
    return sample_indices / envelope.fs
Esempio n. 14
0
 def target_signal(self) -> BinarySignal:
     """
     Binary target, or training signal, as a one-column matrix.

     The training reference segments are scaled by
     `1 + config.reference_seg_extension` and then drawn onto an all-zero
     array that has as many samples as `self.reference_channel_train`.
     `config.target_fullrect` selects the drawing method.
     """
     extended_segs = self.reference_segs_train.scale(
         1 + config.reference_seg_extension, reference=1
     )
     num_samples = self.reference_channel_train.shape[0]
     target = np.zeros(num_samples)
     # Pick how the segment times are turned into a binary signal.
     # (`self._add_triangles(target, extended_segs)` is currently disabled.)
     draw = self._add_rects if config.target_fullrect else self._add_start_rects
     draw(target, extended_segs)
     return Signal(target, self.reference_channel_train.fs).as_matrix()
Esempio n. 15
0
 def work(self):
     """
     Run the selected model over the full input, chunk by chunk, and write the
     resulting envelope.

     The selected channels of `self.multichannel_full` are split along the
     time axis into chunks of roughly `self.seconds_per_chunk` seconds. Each
     chunk is passed through the model; the hidden state `h` is carried over
     from one chunk to the next. The sigmoid of each squeezed model output is
     collected, and the concatenation of all chunks is written as a `Signal`
     with the input's sampling rate.

     At least one chunk is always used, so inputs shorter than
     `self.seconds_per_chunk` are processed instead of failing on a split
     into zero parts.
     """
     envelope_chunks = []
     inputt = self.multichannel_full.as_matrix()[:, self.channels]
     # Full input signal is too big for GPU memory.
     # Thus: split input, pass through h, concatenate results
     # The chunk count must be a positive integer.
     num_chunks = max(1, int(inputt.duration // self.seconds_per_chunk))
     input_chunks = array_split(inputt, num_chunks, axis=inputt.time_axis)
     model: RNN = self.model_selector.output().read()
     h = model.get_init_h()
     for i, input_chunk in enumerate(input_chunks):
         log.info(f"Transforming chunk {i} of {num_chunks}")
         with torch.no_grad():
             input_torch = self.as_model_io(input_chunk)
             output, h = model(input_torch, h)
             # Cannot use torch.nn.functional.sigmoid (deprecated).
             envelope: TorchArray = torch.sigmoid(output.squeeze())
             envelope_cpu = envelope.to("cpu")
             envelope_numpy = envelope_cpu.numpy()
         envelope_chunks.append(envelope_numpy)
     envelope_full = concatenate(envelope_chunks)
     sig = Signal(envelope_full, inputt.fs)
     self.output().write(sig)
Esempio n. 16
0
 def plot_input(self, ax, trange):
     """
     Plot three LFP channels of `rm` on `ax` and add a voltage scalebar.

     :param ax:  Axes handed to `plot_sig` and `add_voltage_scalebar`.
     :param trange:  Handed to `plot_sig`.
     """
     # NOTE(review): `rm` is not bound inside this method; presumably it comes
     # from an enclosing or module scope -- confirm.
     channel_columns = [rm.sr_channel, rm.ripple_channel, rm.toppyr_channel]
     # One column per channel.
     lfp = Signal(stack(channel_columns, axis=1), rm.sr_channel.fs)
     plot_sig(lfp, ax, trange)
     add_voltage_scalebar(ax)
Esempio n. 17
0
 def SW_LPF(self, signal: Signal) -> Signal:
     """
     Low-pass filter `signal`.

     An order-5 "low" filter is designed with `butter`, with cutoff
     `self.SW_cutoff` expressed relative to the Nyquist frequency
     (`signal.fs / 2`), and applied with `filtfilt`.

     :return:  Filtered data as a `Signal` with the sampling rate of `signal`.
     """
     nyquist = signal.fs / 2
     filter_order = 5
     normalised_cutoff = self.SW_cutoff / nyquist
     coefficients = butter(filter_order, normalised_cutoff, "low")
     filtered = filtfilt(*coefficients, signal)
     return Signal(filtered, signal.fs)
Esempio n. 18
0
def plot_signal(signal: Signal,
                time_range: Tuple[float, float],
                y_scale: float = 500,
                height: float = 0.5,
                channels: Optional[ndarray] = None,
                bottom_first: bool = True,
                tight_ylims: bool = False,
                zero_lines: bool = True,
                time_grid: bool = True,
                y_grid: Optional[bool] = None,
                color: Color = "black",
                lw: float = 0.9,
                ax: Optional[Axes] = None,
                **kwargs) -> Tuple[Figure, Axes]:
    """
    Plot a time-slice of a single- or multichannel signal.

    When the signal is multichannel, each channel will be plotted with the same
    scale.

    :param signal:  The signal to plot.
    :param time_range:  Time slice to plot. In seconds.
    :param channels:  Which channels to plot. Plots all channels by default.
    :param height:  Height of each channel, in inches. Only relevant when no
                Axes is given.
    :param y_scale:  How much data-y-units the visual vertical spacing between
                channels represents.
    :param bottom_first:  If True (default), the first channel will be
                plotted at the bottom of the figure.
    :param tight_ylims:  If True, adapts the ylims to tightly fit the data
                visible in `time_range`. If False (default), makes sure that
                plots of different time ranges of `signal` will all have the
                same ylims.
    :param zero_lines:  Whether to plot grey y==0 lines in each channel.
    :param time_grid:  Whether to plot vertical gridlines, with corresponding
                absolute time ticks and ticklabels.
    :param y_grid:  Whether to plot horizontal gridlines, with corresponding
                y-ticks and -ticklabels. By default, only plots a y-grid if the
                signal is single-channel and `zero_lines` is False.
    :param color:  Line color. Overrides a `color` entry in `kwargs`.
    :param lw:  Line width. Overrides an `lw` entry in `kwargs`.
    :param ax:  The axes to plot on. If None (default), creates a new figure and
                axes.
    :param kwargs:  Passed on to `ax.plot()`.
    :return:  The figure and the axes that were plotted on.
    """
    signal = signal.as_matrix()
    if ax is None:
        fig, ax = subplots(figsize=(12, height * signal.num_channels))
    else:
        fig = ax.get_figure()
    if y_grid is None:
        y_grid = signal.num_channels == 1 and not zero_lines
    if channels is None:
        channels = arange(signal.num_channels)
    ix = time_to_index(time_range,
                       signal.fs,
                       arr_size=signal.num_samples,
                       clip=True)
    y: Signal = signal[slice(*ix), channels]
    t = y.get_time_vector(t0=time_range[0])
    # One offset per plotted channel. Counting the channels of the full signal
    # instead would break (or wrongly broadcast) when `channels` selects only
    # a subset.
    num_plotted = y.shape[1] if y.ndim > 1 else 1
    if bottom_first:
        y_offsets = y_scale * arange(0, num_plotted)
    else:
        y_offsets = y_scale * arange(0, -num_plotted, -1)
    y_separated = y + y_offsets
    if zero_lines:
        ax.hlines(y_offsets, *time_range, colors="grey", lw=1)
    kwargs.update(color=color, lw=lw)
    ax.plot(t, y_separated, **kwargs)
    ax.set_xlim(time_range)
    if not tight_ylims:
        ax.set_ylim(_get_global_ylims(signal, y_scale))
    if time_grid:
        ax.set_xlabel("Time (s)")
    else:
        # The axis is selected with `axis=`; `which=` only accepts
        # "major" / "minor" / "both".
        ax.grid(False, axis="x")
        ax.set_xticks([])
    if not y_grid:
        ax.grid(False, axis="y")
        ax.set_yticks([])
    return fig, ax