Beispiel #1
0
    async def coro_calculate(self, calc_all: bool):
        """
        Coroutine to calculate the transform for the selected signal, or for all signals.

        Any calculation still running on the previous handler is stopped before a
        new ``MPHandler`` replaces it.

        :param calc_all: whether to calculate for all signals, or only the selected one
        :raises Exception: if the WFT is selected but no minimum frequency is defined
        """
        self.is_calculating_all = calc_all

        # Stop the previous handler (if any) before replacing it, so that a
        # calculation that is still running does not keep going in the background.
        if self.mp_handler:
            self.mp_handler.stop()
        self.mp_handler = MPHandler()

        params = self.get_params(all_signals=calc_all)
        if params.transform == _wft:
            if self.view.get_fmin() is None:
                raise Exception("Minimum frequency must be defined for WFT.")

        self.is_plotted = False
        self.view.main_plot().clear()
        self.view.main_plot().set_in_progress(True)
        self.invalidate_data()

        # Only the wavelet transform is shown on a logarithmic scale.
        log: bool = (params.transform == _wt)
        self.view.main_plot().set_log_scale(logarithmic=log)
        self.view.amplitude_plot().set_log_scale(logarithmic=log)

        self.view.on_calculate_started()

        all_data = await self.mp_handler.coro_transform(
            params, self.on_progress_updated)

        for d in all_data:
            self.on_transform_completed(*d)
Beispiel #2
0
    async def coro_preprocess_selected_signal(self) -> List[ndarray]:
        """
        Preprocess the currently selected signal without any frequency limits.

        The preprocessing handler is created lazily on first use and reused afterwards.

        :return: whatever the handler's ``coro_preprocess`` returns
        """
        selected = self.get_selected_signal()

        handler = self.preproc_mp_handler
        if not handler:
            handler = self.preproc_mp_handler = MPHandler()

        # No fmin/fmax: preprocessing is run unbounded.
        return await handler.coro_preprocess(selected, None, None)
Beispiel #3
0
    async def coro_calculate(self) -> None:
        """
        Coroutine to calculate group coherence.

        Dual-group coherence is calculated instead when a second group of
        signals is present. Saving and the statistics group box are disabled
        while the calculation runs. The calculation uses a fresh ``MPHandler``,
        and the plots are updated once it finishes.
        """
        self.view.enable_save_data(False)
        self.view.groupbox_stats.setEnabled(
            self.should_stats_be_enabled(calculating=True))

        if self.mp_handler:
            self.mp_handler.stop()

        self.is_plotted = False
        self.invalidate_data()

        self.view.main_plot().clear()
        self.view.main_plot().set_in_progress(True)
        self.view.amplitude_plot().clear()
        self.view.amplitude_plot().set_in_progress(True)

        params = self.get_params()
        self.params = params

        self.mp_handler = MPHandler()

        self.view.on_calculate_started()

        if "PYMODALIB_CACHE" not in os.environ:
            cache = Settings().get_pymodalib_cache()
            # Compare by value, not identity, so that a stored "None" string is skipped.
            if cache and cache != "None":
                os.environ["PYMODALIB_CACHE"] = cache

        sig1a, sig1b, sig2a, sig2b = self.signals.get_all()
        if sig2a is None:
            # Only one group of signals: plain group coherence.
            self.results = (await self.mp_handler.coro_group_coherence(
                sig1a,
                sig1b,
                fs=self.signals.frequency,
                on_progress=self.on_progress_updated,
                **params,
            ))[0]
        else:
            self.results = (await self.mp_handler.coro_dual_group_coherence(
                sig1a,
                sig1b,
                sig2a,
                sig2b,
                fs=self.signals.frequency,
                on_progress=self.on_progress_updated,
                **params,
            ))[0]

        self.view.groupbox_stats.setEnabled(
            self.should_stats_be_enabled(calculating=False))
        # NOTE(review): the start of this method calls self.view.enable_save_data;
        # confirm that self.enable_save_data exists on this presenter.
        self.enable_save_data(True)

        self.update_plots()
        self.view.on_calculate_stopped()
        self.on_all_tasks_completed()
Beispiel #4
0
    async def coro_calculate(self):
        """
        Coroutine which calculates the bispectra for the loaded signals.

        Any previous handler is stopped, and the parameters are snapshotted on
        ``self.params``. The bispectrum analysis then runs on a fresh ``MPHandler``.
        """
        old_handler = self.mp_handler
        if old_handler:
            old_handler.stop()

        self.is_plotted = False
        self.invalidate_data()

        # Snapshot the parameters now: the biphase calculation must reuse them
        # even if the UI is edited in the meantime.
        self.params = self.get_params()

        self.mp_handler = handler = MPHandler()
        results = await handler.coro_bispectrum_analysis(
            self.signals, self.params, self.on_progress_updated
        )

        for result in results:
            self.on_bispectrum_completed(*result)

        self.view.on_calculate_stopped()
        self.update_plots()
Beispiel #5
0
    async def coro_preprocess_selected_signal(self) -> List[ndarray]:
        """
        Preprocess the selected signal using the frequency limits entered in the view.

        The preprocessing handler is created on first use and reused afterwards.

        :return: the result of the handler's ``coro_preprocess``
        """
        selected = self.get_selected_signal()
        lower = self.view.get_fmin()
        upper = self.view.get_fmax()

        handler = self.preproc_mp_handler
        if not handler:
            handler = self.preproc_mp_handler = MPHandler()

        return await handler.coro_preprocess(selected, lower, upper)
Beispiel #6
0
    async def coro_plot_preprocessed_signal(self) -> None:
        """
        Preprocess the selected signal and plot the result next to the original.

        Nothing is plotted if preprocessing returns an empty result or a result
        whose first entry is None.
        """
        signal = self.get_selected_signal()
        lower, upper = self.view.get_fmin(), self.view.get_fmax()

        handler = self.preproc_mp_handler
        if not handler:
            handler = self.preproc_mp_handler = MPHandler()

        outcome = await handler.coro_preprocess(signal, lower, upper)
        if not outcome or outcome[0] is None:
            return

        self.view.plot_preprocessed_signal(signal.times, signal.signal, outcome[0])
Beispiel #7
0
    async def coro_preprocess_all_signals(self) -> List[ndarray]:
        """
        Preprocess every loaded signal.

        The frequency limits are read from the view. If the view does not
        provide them (AttributeError), preprocessing runs without limits.

        :return: the result of the handler's ``coro_preprocess`` for all signals
        """
        everything = self.signals

        try:
            lower = self.view.get_fmin()
            upper = self.view.get_fmax()
        except AttributeError:
            # Some views have no frequency inputs; fall back to no limits.
            lower = upper = None

        handler = self.preproc_mp_handler
        if not handler:
            handler = self.preproc_mp_handler = MPHandler()

        return await handler.coro_preprocess(everything, lower, upper)
Beispiel #8
0
    async def coro_calculate(self, calculate_all: bool) -> None:
        """
        Coroutine which calculates the wavelet transforms and then the phase coherence.

        :param calculate_all: whether to calculate for all signals, or just the selected one
        """
        self.view.enable_save_data(False)
        self.is_calculating_all = calculate_all

        previous_handler = self.mp_handler
        if previous_handler:
            previous_handler.stop()

        self.is_plotted = False
        self.invalidate_data()

        main_plot = self.view.main_plot
        amplitude_plot = self.view.amplitude_plot
        main_plot().clear()
        main_plot().set_in_progress(True)
        amplitude_plot().clear()
        amplitude_plot().set_in_progress(True)

        self.params = params = self.get_params(all_signals=calculate_all)

        # Surrogate settings are snapshotted alongside the parameters.
        self.surrogate_count = self.view.get_surr_count()
        self.surrogates_enabled = self.view.get_surr_enabled()

        self.mp_handler = MPHandler()

        main_plot().set_log_scale(logarithmic=True)
        amplitude_plot().set_log_scale(logarithmic=True)

        self.view.on_calculate_started()

        # Step 1: wavelet transforms.
        transforms = await self.mp_handler.coro_transform(
            params=params, on_progress=self.on_progress_updated)
        for item in transforms:
            self.on_transform_completed(*item)

        # Step 2: phase coherence for the signals being calculated.
        coherence = await self.coro_phase_coherence(
            self.signals_calc, params, self.on_progress_updated)
        for item in coherence:
            self.on_phase_coherence_completed(*item)

        self.plot_phase_coherence()
        self.view.on_calculate_stopped()
        self.on_all_tasks_completed()
        print("Finished calculating phase coherence.")
Beispiel #9
0
    async def coro_calculate(self):
        """
        Coroutine which runs dynamical Bayesian inference for every parameter set.

        Results are plotted when any are returned. Otherwise a hint about
        missing parameter sets is printed.
        """
        if self.mp_handler:
            self.mp_handler.stop()

        self.mp_handler = handler = MPHandler()
        results = await handler.coro_bayesian(
            self.signals, self.get_paramsets(), self.on_progress_updated
        )

        for result in results:
            self.on_bayesian_inference_completed(*result)

        if not results:
            print("No data returned; are any parameter sets added?")
        else:
            self.plot_bayesian()

        print("Dynamical Bayesian inference completed.")
        self.view.on_calculate_stopped()
Beispiel #10
0
    async def coro_calculate(self, calculate_all: bool) -> None:
        """
        Coroutine to perform the harmonics calculation.

        Clears any previous results and runs the calculation on a fresh
        ``MPHandler``. Each calculated signal then receives one result in
        ``signal.data``.

        :param calculate_all: whether to calculate for all signals, or just the current signal
        """
        self.is_calculating_all = calculate_all
        self.view.enable_save_data(False)
        for s in self.signals:
            s.data = None

        if self.mp_handler:
            self.mp_handler.stop()

        self.mp_handler = MPHandler()

        params = self.get_params(all_signals=calculate_all)
        self.params = params
        # `params` and `self.params` are the same object, so this also sets `params.preprocess`.
        self.params.preprocess = self.view.get_preprocess()

        self.is_plotted = False
        self.view.main_plot().clear()
        self.invalidate_data()

        self.view.main_plot().set_log_scale(logarithmic=True)

        self.view.on_calculate_started()

        all_data = await self.mp_handler.coro_harmonics(
            self.signals_calc, params, self.params.preprocess,
            self.on_progress_updated)

        # Presumably a single result can come back unwrapped; normalise to a list
        # so it can be zipped with the calculated signals. Use the builtin `list`
        # rather than the deprecated `typing.List` alias for the runtime check.
        if not isinstance(all_data, list):
            all_data = [all_data]

        for signal, data in zip(self.signals_calc, all_data):
            signal.data = data

        self.view.enable_save_data(bool(all_data))
        self.view.on_calculate_stopped()
        self.update_plots()
Beispiel #11
0
class BAPresenter(BaseTFPresenter):
    """
    Presenter for the bispectrum analysis window.

    Runs the bispectrum and biphase calculations through an ``MPHandler`` and
    draws the results in the associated ``BAWindow``.
    """

    def __init__(self, view):
        super().__init__(view)
        # Parameters of the last bispectrum calculation; reused by the biphase
        # calculation so that later UI edits do not affect it.
        self.params: BAParams = None

        from gui.windows.bispectrum.BAWindow import BAWindow

        self.view: BAWindow = view

    def init(self):
        """Initialises the presenter and switches the view to the dual-plot layout."""
        super().init()
        self.view.switch_to_dual_plot()

    def calculate(self, calculate_all: bool):
        """
        Starts the coroutine which calculates the data.

        :param calculate_all: currently ignored; ``coro_calculate`` always uses ``self.signals``
        """
        asyncio.ensure_future(self.coro_calculate())
        self.view.on_calculate_started()

    async def coro_calculate(self):
        """Coroutine which calculates the bispectra."""
        if self.mp_handler:
            self.mp_handler.stop()

        self.is_plotted = False
        self.invalidate_data()

        # We need to use the same params for biphase later; save now because the UI could change.
        self.params = self.get_params()

        self.mp_handler = MPHandler()
        data = await self.mp_handler.coro_bispectrum_analysis(
            self.signals, self.params, self.on_progress_updated
        )

        for d in data:
            self.on_bispectrum_completed(*d)

        self.view.on_calculate_stopped()
        self.update_plots()

    def add_point(self, x: float, y: float):
        """
        Starts the coroutine which calculates the biphase and biamplitude for the
        frequency pair selected in the view.

        :param x: x-coordinate of the selected point (passed on to ``coro_biphase``)
        :param y: y-coordinate of the selected point (passed on to ``coro_biphase``)
        """
        asyncio.ensure_future(self.coro_biphase(x, y))
        self.view.on_calculate_started()

    async def coro_biphase(self, x: float, y: float):
        """
        Coroutine which calculates the biphase and biamplitude for the frequency
        pair currently selected in the view.

        :param x: ignored; the selected pair is re-read from the view
        :param y: ignored; the selected pair is re-read from the view
        """
        if self.mp_handler:
            self.mp_handler.stop()

        fr = self.view.get_selected_freq_pair()
        freq_x, freq_y = fr

        if not self.mp_handler or self.params is None or freq_x is None or freq_y is None:
            # Nothing can be calculated (no bispectrum yet, or no pair selected).
            # `add_point` has already signalled a started calculation, so undo it.
            self.view.on_calculate_stopped()
            return

        data = await self.mp_handler.coro_biphase(
            self.signals, self.params.fs, self.params.f0, fr, self.on_progress_updated
        )

        for d in data:
            self.on_biphase_completed(*d)

        self.view.on_calculate_stopped()
        self.update_side_plots(self.get_selected_signal_pair()[0].output_data)

    def update_plots(self):
        """
        Updates all plots according to the currently selected plot type and current data.
        """
        data = self.get_selected_signal_pair()[0].output_data

        self.update_main_plot(data)
        self.update_side_plots(data)

    def update_main_plot(self, data):
        """
        Updates the main plot, plotting the wavelet transform or bispectrum depending
        on the selected plot type.

        :param data: the data object
        """
        x, y, c, log = self.get_main_plot_data(self.view.get_plot_type(), data)
        freq_x, freq_y = self.view.get_selected_freq_pair()

        is_wt = self.view.is_wt_selected()
        if is_wt:
            self.view.switch_to_dual_plot()
        else:
            self.view.switch_to_triple_plot()

        if freq_x is not None and freq_y is not None:
            self.view.plot_main.draw_crosshair(freq_x, freq_y)

        if c is not None:
            self.view.plot_main.set_log_scale(log, axis="x")
            self.view.plot_main.set_log_scale(True, axis="y")
            # Wavelet transforms are plotted against time (x = data.times);
            # bispectra are plotted against frequency on both axes.
            self.view.plot_main.update_xlabel("Time (s)" if is_wt else "Frequency (Hz)")
            self.view.plot_main.update_ylabel("Frequency (Hz)")
            self.view.plot_main.plot(x=x, y=y, c=c)

    def update_side_plots(self, data: BAOutputData):
        """
        Updates the side plot(s). For wavelet transforms this will be
        the average amplitude/power, but for bispectra this will be the
        biphase and biamplitude.

        :param data: the data object
        """
        if self.view.is_wt_selected():  # Plot average amplitude or power.
            amp_not_power = self.view.is_amplitude_selected()
            x, y = self.get_side_plot_data_wt(
                self.view.get_plot_type(), data, amp_not_power
            )

            if x is not None and y is not None and len(x) > 0:
                plot = self.view.plot_right_top
                # Label the axis with the quantity actually plotted.
                plot.set_xlabel("Average amplitude" if amp_not_power else "Average power")
                plot.set_ylabel("Frequency (Hz)")

                plot.set_log_scale(True, "y")
                plot.plot(x, y)
        else:  # Plot biphase and biamplitude.
            biamp, biphase = self.get_side_plot_data_bispec(
                self.view.get_plot_type(), self.view.get_selected_freq_pair(), data
            )

            times = self.get_selected_signal().times

            if biamp is not None:
                self.view.plot_right_middle.update_ylabel("Biamplitude")
                self.view.plot_right_middle.update_xlabel("Time (s)")
                self.view.plot_right_middle.plot(times, biamp)

            if biphase is not None:
                self.view.plot_right_bottom.update_ylabel("Biphase")
                self.view.plot_right_bottom.update_xlabel("Time (s)")
                self.view.plot_right_bottom.plot(times, biphase)

    @staticmethod
    def get_side_plot_data_wt(
        plot_type: str, data: BAOutputData, amp_not_power: bool
    ) -> Tuple[ndarray, ndarray]:
        """
        Gets the data required to plot the average amplitude/power on the side plot. Used
        when a wavelet transform is selected.

        :param plot_type: the plot type shown in the QComboBox, e.g. "Wavelet transform 1"
        :param data: the data object
        :param amp_not_power: whether the data should be amplitude, or power
        :return: the x-values and the y-values, or (None, None) if the plot type
            is not a wavelet transform
        """
        _dict = {
            "Wavelet transform 1": (
                data.avg_amp_wt1 if amp_not_power else data.avg_pow_wt1,
                data.freq,
            ),
            "Wavelet transform 2": (
                data.avg_amp_wt2 if amp_not_power else data.avg_pow_wt2,
                data.freq,
            ),
        }
        # The caller unpacks the result, so never return a bare None.
        return _dict.get(plot_type, (None, None))

    @staticmethod
    def get_side_plot_data_bispec(
        plot_type: str, freq: Tuple[float, float], data: BAOutputData
    ):
        """
        Gets the data required to plot the biphase and biamplitude on the side plots.
        Used when bispectrum is selected.

        :param plot_type: the plot type shown in the QComboBox, e.g. "b111"
        :param freq: the selected frequencies (x and y)
        :param data: the data object
        :return: biamplitude and biphase, or (None, None) if they are unavailable
        """
        # Same key format as used by on_biphase_completed.
        key = ", ".join(str(f) for f in freq)

        try:
            biamp = data.biamp.get(key)
            biphase = data.biphase.get(key)
        except AttributeError:  # `data` has no biphase results (e.g. it is None).
            biamp = None
            biphase = None

        if "None" in key or biamp is None:
            return None, None

        # Position of each bispectrum type in the lists built by on_biphase_completed.
        index = {"b111": 0, "b222": 1, "b122": 2, "b211": 3}.get(plot_type)
        if index is None:
            return None, None
        return biamp[index], biphase[index]

    @staticmethod
    def get_main_plot_data(
        plot_type: str, data: BAOutputData
    ) -> Tuple[ndarray, ndarray, ndarray, bool]:
        """
        Gets the relevant arrays to plot in the main plot (WT or bispectrum).

        :param plot_type: the type of plot, as shown in the QComboBox; e.g. "Wavelet transform 1"
        :param data: the data object with all data
        :return: the x-values, the y-values, the c-values, and whether to use a log scale;
            all None if there is no data or the plot type is not supported
        """
        empty = (None, None, None, None)
        if not isinstance(data, BAOutputData):
            return empty

        _dict = {
            "Wavelet transform 1": (data.times, data.freq, data.amp_wt1, False),
            "Wavelet transform 2": (data.times, data.freq, data.amp_wt2, False),
            "b111": (data.freq, data.freq, data.bispxxx, True),
            "b222": (data.freq, data.freq, data.bispppp, True),
            "b122": (data.freq, data.freq, data.bispxpp, True),
            "b211": (data.freq, data.freq, data.bisppxx, True),
        }
        plot_data = _dict.get(plot_type)
        if plot_data is None:  # All plots.
            # TODO: plotting all types at once is not implemented; plot nothing.
            return empty

        return plot_data

    def on_bispectrum_completed(
        self,
        name: str,
        freq: ndarray,
        amp_wt1: ndarray,
        pow_wt1: ndarray,
        avg_amp_wt1: ndarray,
        avg_pow_wt1: ndarray,
        amp_wt2: ndarray,
        pow_wt2: ndarray,
        avg_amp_wt2: ndarray,
        avg_pow_wt2: ndarray,
        bispxxx: ndarray,
        bispppp: ndarray,
        bispxpp: ndarray,
        bisppxx: ndarray,
        surrxxx: ndarray,
        surrppp: ndarray,
        surrxpp: ndarray,
        surrpxx: ndarray,
        opt: dict,
    ):
        """
        Stores the results of one bispectrum calculation as ``BAOutputData`` on
        the signal named ``name``. The biamplitude/biphase dicts start empty.
        """
        # Attach the data to the first signal in the current pair.
        sig = self.signals.get(name)

        sig.output_data = BAOutputData(
            amp_wt1,
            pow_wt1,
            avg_amp_wt1,
            avg_pow_wt1,
            amp_wt2,
            pow_wt2,
            avg_amp_wt2,
            avg_pow_wt2,
            sig.times,
            freq,
            bispxxx,
            bispppp,
            bispxpp,
            bisppxx,
            surrxxx,
            surrppp,
            surrxpp,
            surrpxx,
            opt,
            {},
            {},
        )

    def on_biphase_completed(
        self,
        name: str,
        freq_x: float,
        freq_y: float,
        biamp1: ndarray,
        biphase1: ndarray,
        biamp2: ndarray,
        biphase2: ndarray,
        biamp3: ndarray,
        biphase3: ndarray,
        biamp4: ndarray,
        biphase4: ndarray,
    ):
        """
        Stores the biamplitudes and biphases for one frequency pair on the
        signal named ``name``, keyed by ``"<freq_x>, <freq_y>"``.
        """
        sig = self.signals.get(name)
        key = f"{freq_x}, {freq_y}"

        data = sig.output_data

        # Ordered b111, b222, b122, b211 (see get_side_plot_data_bispec).
        data.biamp[key] = [biamp1, biamp2, biamp3, biamp4]
        data.biphase[key] = [biphase1, biphase2, biphase3, biphase4]

    def load_data(self):
        """
        Loads the data from a file, showing a dialog to set the frequency of
        the signal if the file does not provide one.

        :raises Exception: if the frequency dialog returns no frequency
        """
        self.signals = SignalPairs.from_file(self.open_file)

        if self.signals.has_frequency():
            # The file already provides the frequency; the view still has to be updated.
            self.on_data_loaded()
            return

        freq = FrequencyDialog().run_and_get()

        if freq:
            self.set_frequency(freq)
            self.on_data_loaded()
        else:
            raise Exception("Frequency was None. Perhaps it was mistyped?")

    def on_data_loaded(self):
        """Shows the loaded signal pairs in the list view and plots the selected pair."""
        self.view.update_signal_listview(self.signals.get_pair_names())
        self.plot_signal_pair()

    def plot_signal_pair(self):
        """Plots the currently selected signal pair."""
        self.view.plot_signal_pair(self.get_selected_signal_pair())

    def on_signal_selected(self, item: Union[QListWidgetItem, str]):
        """
        Handles selection of a signal pair in the list view.

        :param item: the selected list item, or the name of the pair
        """
        if isinstance(item, QListWidgetItem):
            name = item.text()
        else:
            name = item

        self.signals.reset()
        if name != self.selected_signal_name:
            print(f"Selected '{name}'")
            self.selected_signal_name = name
            self.plot_signal_pair()
            self.view.on_xlim_edited()

            self.plot_preprocessed_signal()

    def get_selected_signal_pair(self) -> Tuple[TimeSeries, TimeSeries]:
        """
        Gets the currently selected signal pair as a tuple containing 2 signals.
        """
        return self.signals.get_pair_by_name(self.selected_signal_name)

    def get_params(self) -> BAParams:
        """Gets data from the GUI to create the params used by the bispectrum calculation."""
        return BAParams(
            signals=self.signals,
            preprocess=self.view.get_preprocess(),
            fmin=self.view.get_fmin(),
            fmax=self.view.get_fmax(),
            f0=self.view.get_f0(),
            nv=self.view.get_nv(),
            surr_count=self.view.get_surr_count(),
            opt={},
        )
Beispiel #12
0
class BAPresenter(BaseTFPresenter):
    def __init__(self, view):
        super().__init__(view)
        # Parameters of the last bispectrum calculation; reused for the biphase
        # calculation and when saving, so later UI edits do not affect them.
        self.params: BAParams = None
        # Options returned by the last bispectrum calculation (set in
        # on_bispectrum_completed). None until a calculation has finished, which
        # coro_get_data_to_save checks for.
        self.opt: Dict = getattr(self, "opt", None)
        # Whether amplitude (True) or power (False) is plotted. The plotting
        # methods read this before set_plot_type() may have been called, so it
        # needs a default. Keep any value the base class may already have set.
        self.plot_ampl: bool = getattr(self, "plot_ampl", True)

        from gui.windows.bispectrum.BAWindow import BAWindow

        self.view: BAWindow = view

    def init(self):
        """Run the base-class initialisation, then start with the two-plot layout."""
        super().init()
        self.view.switch_to_dual_plot()

    def calculate(self, calculate_all: bool):
        """
        Schedule the bispectrum calculation on the event loop and notify the view.

        :param calculate_all: not used; ``coro_calculate`` takes no arguments
        """
        pending = self.coro_calculate()
        asyncio.ensure_future(pending)
        self.view.on_calculate_started()

    async def coro_calculate(self):
        """
        Coroutine which calculates the bispectra for the loaded signals.

        Saving is disabled while the calculation runs. It is re-enabled once
        every result has been stored.
        """
        self.enable_save_data(False)

        old_handler = self.mp_handler
        if old_handler:
            old_handler.stop()

        self.is_plotted = False
        self.invalidate_data()

        # Snapshot the parameters: the biphase calculation must reuse the same
        # values even if the UI is edited in the meantime.
        self.params = self.get_params()

        self.mp_handler = handler = MPHandler()
        results = await handler.coro_bispectrum_analysis(
            self.signals, self.params, self.on_progress_updated)

        for result in results:
            self.on_bispectrum_completed(*result)

        self.enable_save_data(True)
        self.view.on_calculate_stopped()
        self.update_plots()

    def add_point(self, x: float, y: float):
        """
        Schedule the biphase calculation for a newly selected point and notify the view.

        :param x: x-coordinate of the point, passed on to ``coro_biphase``
        :param y: y-coordinate of the point, passed on to ``coro_biphase``
        """
        pending = self.coro_biphase(x, y)
        asyncio.ensure_future(pending)
        self.view.on_calculate_started()

    async def coro_biphase(self, x: float, y: float):
        """
        Coroutine which calculates the biphase and biamplitude for the frequency
        pair currently selected in the view.

        If no bispectrum has been calculated yet, or no pair is selected, the
        view is told the calculation stopped and the save state is left unchanged.

        :param x: ignored; the selected pair is re-read from the view
        :param y: ignored; the selected pair is re-read from the view
        """
        if self.mp_handler:
            self.mp_handler.stop()

        fr = self.view.get_selected_freq_pair()
        freq_x, freq_y = fr

        if not self.mp_handler or self.params is None or freq_x is None or freq_y is None:
            # Nothing to calculate. `add_point` has already signalled that a
            # calculation started, so signal that it stopped again.
            self.view.on_calculate_stopped()
            return

        self.enable_save_data(False)

        data = await self.mp_handler.coro_biphase(self.signals, self.params.fs,
                                                  self.params.f0, fr,
                                                  self.on_progress_updated)

        for d in data:
            self.on_biphase_completed(*d)

        self.enable_save_data(True)

        self.view.on_calculate_stopped()
        self.update_side_plots(
            self.get_selected_signal_pair()[0].output_data)

    def update_plots(self):
        """
        Redraw the main plot and the side plots for the selected signal pair.

        An AttributeError raised while plotting is ignored. A ValueError about a
        zero-size reduction is reported as an unavailable plot type; any other
        ValueError is re-raised.
        """
        output = self.get_selected_signal_pair()[0].output_data

        try:
            self.update_main_plot(output)
            self.update_side_plots(output)
        except AttributeError:
            pass
        except ValueError as err:
            if "zero-size array to reduction operation minimum which has no identity" not in str(err):
                raise err
            print(
                f"'{self.view.combo_plot_type.currentText()}' is not available to plot."
            )

    def set_plot_type(self, amplitude_selected=True) -> None:
        """
        Choose between plotting amplitude and power, then redraw all plots.

        The choice affects both the main plot and the amplitude (side) plot.

        :param amplitude_selected: True to plot amplitude, False to plot power
        """
        self.plot_ampl = amplitude_selected
        self.update_plots()

    def update_main_plot(self, data):
        """
        Updates the main plot, plotting the wavelet transform or bispectrum depending
        on the selected plot type.

        :param data: the data object
        """
        x, y, c, log = self.get_main_plot_data(self.view.get_plot_type(), data)
        freq_x, freq_y = self.view.get_selected_freq_pair()

        is_wt = self.view.is_wt_selected()
        if is_wt:
            self.view.switch_to_dual_plot()
        else:
            self.view.switch_to_triple_plot()

        if freq_x is not None and freq_y is not None:
            self.view.plot_main.draw_crosshair(freq_x, freq_y)

        if c is not None:
            self.view.plot_main.set_log_scale(log, axis="x")
            self.view.plot_main.set_log_scale(True, axis="y")
            # Wavelet transforms are plotted against time (x = data.times);
            # bispectra are plotted against frequency on both axes.
            self.view.plot_main.update_xlabel("Time (s)" if is_wt else "Frequency (Hz)")
            self.view.plot_main.update_ylabel("Frequency (Hz)")
            self.view.plot_main.plot(x=x, y=y, c=c)

    def update_side_plots(self, data: BAOutputData):
        """
        Updates the side plot(s). For wavelet transforms this will be
        the average amplitude/power, but for bispectra this will be the
        biphase and biamplitude.

        :param data: the data object
        """
        if self.view.is_wt_selected():  # Plot average amplitude or power.
            x, y = self.get_side_plot_data_wt(self.view.get_plot_type(), data,
                                              self.plot_ampl)

            if x is not None and y is not None and len(x) > 0:
                plot = self.view.plot_right_top
                # Label the axis with the quantity actually plotted.
                plot.set_xlabel("Average amplitude" if self.plot_ampl else "Average power")
                plot.set_ylabel("Frequency (Hz)")

                plot.set_log_scale(True, "y")
                plot.plot(x, y)
        else:  # Plot biphase and biamplitude.
            biamp, biphase = self.get_side_plot_data_bispec(
                self.view.get_plot_type(), self.view.get_selected_freq_pair(),
                data)

            times = self.get_selected_signal().times

            if biamp is not None:
                self.view.plot_right_middle.update_ylabel("Biamplitude")
                self.view.plot_right_middle.update_xlabel("Time (s)")
                self.view.plot_right_middle.plot(times, biamp)

            if biphase is not None:
                self.view.plot_right_bottom.update_ylabel("Biphase")
                self.view.plot_right_bottom.update_xlabel("Time (s)")
                self.view.plot_right_bottom.plot(times, biphase)

    @staticmethod
    def get_side_plot_data_wt(plot_type: str, data: BAOutputData,
                              amp_not_power: bool) -> Tuple[ndarray, ndarray]:
        """
        Gets the data required to plot the average amplitude/power on the side plot. Used
        when a wavelet transform is selected.

        :param plot_type: the plot type shown in the QComboBox, e.g. "Wavelet transform 1"
        :param data: the data object
        :param amp_not_power: whether the data should be amplitude, or power
        :return: the x-values and the y-values, or (None, None) if the plot type
            is not a wavelet transform
        """
        _dict = {
            "Wavelet transform 1": (
                data.avg_amp_wt1 if amp_not_power else data.avg_pow_wt1,
                data.freq,
            ),
            "Wavelet transform 2": (
                data.avg_amp_wt2 if amp_not_power else data.avg_pow_wt2,
                data.freq,
            ),
        }
        # The caller unpacks the result, so never return a bare None.
        return _dict.get(plot_type, (None, None))

    @staticmethod
    def get_side_plot_data_bispec(plot_type: str, freq: Tuple[float, float],
                                  data: BAOutputData):
        """
        Gets the data required to plot the biphase and biamplitude on the side plots.
        Used when bispectrum is selected.

        :param plot_type: the plot type shown in the QComboBox, e.g. "b111"
        :param freq: the selected frequencies (x and y)
        :param data: the data object
        :return: biamplitude and biphase, or (None, None) if they are unavailable
            or the plot type is not a bispectrum
        """
        # Same key format as used by on_biphase_completed.
        key = ", ".join(str(f) for f in freq)

        try:
            biamp = data.biamp.get(key)
            biphase = data.biphase.get(key)
        except AttributeError:  # `data` has no biphase results (e.g. it is None).
            biamp = None
            biphase = None

        if "None" in key or biamp is None:
            return None, None

        # Position of each bispectrum type in the lists built by on_biphase_completed.
        index = {"b111": 0, "b222": 1, "b122": 2, "b211": 3}.get(plot_type)
        if index is None:
            return None, None
        return biamp[index], biphase[index]

    def get_main_plot_data(
            self, plot_type: str,
            data: BAOutputData) -> Tuple[ndarray, ndarray, ndarray, bool]:
        """
        Gets the relevant arrays to plot in the main plot (WT or bispectrum).

        Bispectra have the surrogate threshold applied when surrogate plotting is
        selected and surrogates were calculated.

        :param plot_type: the type of plot, as shown in the QComboBox; e.g. "Wavelet transform 1"
        :param data: the data object with all data
        :return: the x-values, the y-values, the c-values, and whether to use a log scale;
            all None if there is no data or the plot type is not supported
        """
        empty = (None, None, None, None)
        if not isinstance(data, BAOutputData):
            return empty

        plot_surr = self.view.get_plot_surrogates_selected()

        # Amplitude or power, depending on the selected plot type.
        if self.plot_ampl:
            wt1 = data.amp_wt1
            wt2 = data.amp_wt2
        else:
            wt1 = data.pow_wt1
            wt2 = data.pow_wt2

        _dict = {
            "Wavelet transform 1": (data.times, data.freq, wt1, False),
            "Wavelet transform 2": (data.times, data.freq, wt2, False),
            "b111": (data.freq, data.freq, data.bispxxx, True),
            "b222": (data.freq, data.freq, data.bispppp, True),
            "b122": (data.freq, data.freq, data.bispxpp, True),
            "b211": (data.freq, data.freq, data.bisppxx, True),
        }
        tup = _dict.get(plot_type)
        if tup is None:  # All plots.
            # TODO: plotting all types at once is not implemented; plot nothing
            # instead of passing None on to apply_surrogates / the caller.
            return empty

        if plot_surr and self.params.surr_count > 0 and "transform" not in plot_type:
            tup = self.apply_surrogates(plot_type, data, tup)

        return tup

    def apply_surrogates(
        self,
        plot_type: str,
        data: BAOutputData,
        tup: Tuple[ndarray, ndarray, ndarray, bool],
    ) -> Tuple[ndarray, ndarray, ndarray, bool]:
        """
        Subtracts a surrogate slice from a bispectrum and masks negative values.

        The surrogate slice is taken at index
        ``K = floor((surr_count + 1) * alpha)`` along the third axis. It is
        subtracted from a copy of the bispectrum, and values that become
        negative are replaced by NaN.

        :param plot_type: the bispectrum plot type, e.g. "b111"
        :param data: the data object containing the surrogates
        :param tup: the (x-values, y-values, bispectrum, log-scale) tuple
        :return: the same tuple with the bispectrum replaced by the masked copy,
            or ``tup`` unchanged if ``plot_type`` has no surrogates
        """
        fx, fy, bisp, b = tup

        surrogates = {
            "b111": data.surrxxx,
            "b222": data.surrppp,
            "b122": data.surrxpp,
            "b211": data.surrpxx,
        }.get(plot_type)
        if surrogates is None:
            return tup

        # Builtin int: the `np.int` alias was removed in NumPy 1.24.
        # NOTE(review): assumes the third axis of the surrogate array has more than
        # K entries -- confirm against how the surrogates are stored.
        K = int(np.floor((self.params.surr_count + 1) * self.params.alpha))
        surr = surrogates[:, :, K]

        # Work on a copy so the stored bispectrum is not modified.
        bisp = bisp.copy()

        bisp -= surr
        bisp[bisp < 0] = np.nan  # `np.NAN` was removed in NumPy 2.0.

        return fx, fy, bisp, b

    def on_bispectrum_completed(
        self,
        name: str,
        freq: ndarray,
        amp_wt1: ndarray,
        pow_wt1: ndarray,
        avg_amp_wt1: ndarray,
        avg_pow_wt1: ndarray,
        amp_wt2: ndarray,
        pow_wt2: ndarray,
        avg_amp_wt2: ndarray,
        avg_pow_wt2: ndarray,
        bispxxx: ndarray,
        bispppp: ndarray,
        bispxpp: ndarray,
        bisppxx: ndarray,
        surrxxx: ndarray,
        surrppp: ndarray,
        surrxpp: ndarray,
        surrpxx: ndarray,
        opt: Dict,
    ) -> None:
        """
        Stores the results of one bispectrum calculation.

        The results are wrapped in a ``BAOutputData`` and attached to the signal
        named ``name``. Its biamplitude and biphase dicts start empty and are
        filled later by ``on_biphase_completed``. ``opt`` is also kept on
        ``self.opt``, which the save logic checks.
        """
        self.opt: Dict = opt

        # Attach the data to the first signal in the current pair.
        # NOTE(review): the lookup is by `name`; that this is the pair's first
        # signal is assumed from the original comment -- confirm.
        sig = self.signals.get(name)

        # Positional order must match the BAOutputData constructor exactly.
        sig.output_data = BAOutputData(
            amp_wt1,
            pow_wt1,
            avg_amp_wt1,
            avg_pow_wt1,
            amp_wt2,
            pow_wt2,
            avg_amp_wt2,
            avg_pow_wt2,
            sig.times,
            freq,
            bispxxx,
            bispppp,
            bispxpp,
            bisppxx,
            surrxxx,
            surrppp,
            surrxpp,
            surrpxx,
            opt,
            {},  # biamplitudes, keyed by "<freq_x>, <freq_y>"
            {},  # biphases, keyed by "<freq_x>, <freq_y>"
        )

    def on_biphase_completed(
        self,
        name: str,
        freq_x: float,
        freq_y: float,
        biamp1: ndarray,
        biphase1: ndarray,
        biamp2: ndarray,
        biphase2: ndarray,
        biamp3: ndarray,
        biphase3: ndarray,
        biamp4: ndarray,
        biphase4: ndarray,
    ) -> None:
        """
        Store the biamplitudes and biphases of one frequency pair on the signal
        named ``name``. They are keyed by ``"<freq_x>, <freq_y>"``.
        """
        output = self.signals.get(name).output_data
        key = f"{freq_x}, {freq_y}"

        # One entry per bispectrum type, in the order b111, b222, b122, b211
        # (the indices used by get_side_plot_data_bispec).
        output.biamp[key] = [biamp1, biamp2, biamp3, biamp4]
        output.biphase[key] = [biphase1, biphase2, biphase3, biphase4]

    @override
    async def coro_get_data_to_save(self) -> Dict:
        """
        Collects the bispectrum results of every signal pair into a dict for
        saving.

        Amplitudes and average amplitudes are stacked along the last axis with
        two columns per pair: column ``2 * i`` for the first signal of pair
        ``i`` and column ``2 * i + 1`` for the second. Each bispectrum array
        has one slice per pair along its last axis.

        :return: ``{"BAData": ...}``, or ``None`` if no calculation has been
            performed (no ``opt``/``params``) or there are no signal pairs
        """
        if not self.opt or not self.params:
            return None

        output_data: List[BAOutputData] = [
            s.output_data for s, _ in self.signals.get_pairs()
        ]
        if not output_data:
            return None

        pair_count = len(output_data)
        first = output_data[0]

        amp = np.empty((*first.amp_wt1.shape, pair_count * 2))
        avg_amp = np.empty((first.avg_amp_wt1.shape[0], pair_count * 2))

        b111 = np.empty((*first.bispxxx.shape, pair_count))
        b222 = np.empty(b111.shape)
        b122 = np.empty(b111.shape)
        b211 = np.empty(b111.shape)

        freq = first.freq
        time = first.times
        preproc = []  # TODO: Save this and other params

        # Iterate over pairs (not over amplitude columns): the pair index
        # selects the bispectrum slice, and twice it selects the amplitude
        # column of the pair's first signal.
        for pair_index, d in enumerate(output_data):
            col = pair_index * 2

            amp[:, :, col] = d.amp_wt1
            avg_amp[:, col] = d.avg_amp_wt1

            amp[:, :, col + 1] = d.amp_wt2
            avg_amp[:, col + 1] = d.avg_amp_wt2

            b111[:, :, pair_index] = d.bispxxx
            b222[:, :, pair_index] = d.bispppp
            b122[:, :, pair_index] = d.bispxpp
            b211[:, :, pair_index] = d.bisppxx

        ba_data = {
            "amplitude": amp,
            "avg_amplitude": avg_amp,
            "frequency": freq,
            "time": time,
            "preprocessed_signals": preproc,
            "b111": b111,
            "b222": b222,
            "b122": b122,
            "b211": b211,
            "fr": self.view.get_f0() or 1,
            "fmin": self.opt["fmin"],
            "fmax": self.opt["fmax"],
            "preprocessing": "on" if self.params.preprocess else "off",
            "sampling_frequency": self.params.fs,
        }
        return {"BAData": sanitise(ba_data)}

    def load_data(self):
        """
        Reads the signal pairs from ``self.open_file``.

        If the file does not specify a sampling frequency, a dialog asks the
        user for one. The frequency is then applied and the loaded data shown.

        :raises Exception: if the dialog yields no (or an empty) frequency
        """
        self.signals = SignalPairs.from_file(self.open_file)

        if self.signals.has_frequency():
            # NOTE(review): on_data_loaded() is not called on this path. Confirm
            # that the caller refreshes the view when the file already
            # carries a frequency.
            return

        freq = FrequencyDialog().run_and_get()
        if not freq:
            raise Exception("Frequency was None. Perhaps it was mistyped?")

        self.set_frequency(freq)
        self.on_data_loaded()

    def on_data_loaded(self):
        """Lists the loaded pair names in the view and plots the selected pair."""
        pair_names = self.signals.get_pair_names()
        self.view.update_signal_listview(pair_names)
        self.plot_signal_pair()

    def plot_signal_pair(self):
        """Plots the currently selected signal pair in the view."""
        selected_pair = self.get_selected_signal_pair()
        self.view.plot_signal_pair(selected_pair)

    def on_signal_selected(self, item: Union[QListWidgetItem, str]):
        """
        Handles a selection in the signal-pair list.

        The signals are always reset. If a different pair was chosen, it
        becomes the selected pair: the pair is re-plotted, the x-limits are
        refreshed and the preprocessed signal is plotted.

        :param item: the selected list item, or the pair name itself
        """
        name = item.text() if isinstance(item, QListWidgetItem) else item

        self.signals.reset()
        if name == self.selected_signal_name:
            return

        print(f"Selected '{name}'")
        self.selected_signal_name = name
        self.plot_signal_pair()
        self.view.on_xlim_edited()

        self.plot_preprocessed_signal()

    def get_selected_signal_pair(self) -> Tuple[TimeSeries, TimeSeries]:
        """
        Returns the pair matching the selected name, as a tuple of its two
        signals.
        """
        selected = self.selected_signal_name
        return self.signals.get_pair_by_name(selected)

    def get_params(self) -> BAParams:
        """
        Builds the bispectrum-analysis parameters from the loaded signals and
        the current GUI inputs.
        """
        view = self.view
        return BAParams(
            signals=self.signals,
            preprocess=view.get_preprocess(),
            fmin=view.get_fmin(),
            fmax=view.get_fmax(),
            f0=view.get_f0(),
            nv=view.get_nv(),
            surr_count=view.get_surr_count(),
            alpha=view.get_alpha(),
            opt={},
        )
Beispiel #13
0
class TFPresenter(BaseTFPresenter):
    """
    The presenter in control of the time-frequency window.
    """
    def __init__(self, view):
        """
        :param view: the time-frequency window this presenter controls
        """
        super().__init__(view)

        # Function-level import, presumably to avoid a circular import
        # between the presenter and the window module.
        from gui.windows.timefrequency import TFWindow

        self.view: TFWindow = view
        # Whether the latest calculation covers every signal; overwritten by
        # coro_calculate().
        self.is_calculating_all = True

    def calculate(self, calculate_all: bool):
        """
        Schedules the calculation of the selected transform(s); the result is
        plotted once the calculation completes.

        :param calculate_all: whether to transform every signal, or only the
            selected one
        :return: the result of ``view.show_wft_error()`` when neither a
            minimum frequency is set nor the wavelet transform is selected;
            otherwise ``None``
        """
        # Only the wavelet transform may run without a minimum frequency.
        fmin = self.view.get_fmin()
        if not (fmin or self.view.is_wavelet_transform_selected()):
            return self.view.show_wft_error()

        asyncio.ensure_future(self.coro_calculate(calculate_all))
        print("Started calculation...")

    async def coro_calculate(self, calc_all: bool):
        """
        Coroutine that calculates the transform of all signals, or only of the
        selected one, and passes each result to ``on_transform_completed``.

        A calculation still running on the previous ``MPHandler`` is stopped
        before a new handler is created for this run.

        :param calc_all: whether to calculate for every signal
        :raises Exception: if the WFT is selected without a minimum frequency
        """
        self.is_calculating_all = calc_all

        # Stop the handler of the previous run (if any), then replace it.
        # Stopping a freshly created handler would leave the old one running.
        previous_handler = getattr(self, "mp_handler", None)
        if previous_handler is not None:
            previous_handler.stop()
        self.mp_handler = MPHandler()

        params = self.get_params(all_signals=calc_all)
        if params.transform == _wft:
            if self.view.get_fmin() is None:
                raise Exception("Minimum frequency must be defined for WFT.")

        self.is_plotted = False
        self.view.main_plot().clear()
        self.view.main_plot().set_in_progress(True)
        self.invalidate_data()

        # The wavelet transform is displayed on a logarithmic frequency axis.
        log: bool = (params.transform == _wt)
        self.view.main_plot().set_log_scale(logarithmic=log)
        self.view.amplitude_plot().set_log_scale(logarithmic=log)

        self.view.on_calculate_started()

        all_data = await self.mp_handler.coro_transform(
            params, self.on_progress_updated)

        for d in all_data:
            self.on_transform_completed(*d)

    def on_transform_completed(self, name, times, freq, values, ampl, powers,
                               avg_ampl, avg_pow):
        """
        Stores the transform result for the signal ``name``.

        Once every signal of the current calculation has a valid result, the
        output is plotted.
        """
        self.view.on_calculate_stopped()

        signal = self.signals.get(name)
        signal.output_data = TFOutputData(
            times, values, ampl, freq, powers, avg_ampl, avg_pow
        )
        print(f"Finished calculation for '{name}'.")

        if self.all_transforms_completed():
            self.on_all_transforms_completed()

    def all_transforms_completed(self):
        """
        Returns whether every signal in ``self.signals_calc`` (the signals
        selected by ``get_params``) has valid output data.
        """
        # A generator lets all() stop at the first incomplete signal instead
        # of first building a list of every result.
        return all(s.output_data.is_valid() for s in self.signals_calc)

    def on_all_transforms_completed(self):
        """
        Runs once every transform has finished: plots the output, then
        signals that all tasks are done.
        """
        self.plot_output()
        self.on_all_tasks_completed()

    def get_values_to_plot(self, amplitude=None) -> tuple:
        """
        Returns the data needed to plot the selected signal's transform.

        :param amplitude: if not None, overrides ``self.plot_ampl`` to select
            amplitudes (True) or powers (False)
        :return: ``(times, freq, values, avg_values)``, where the values are
            amplitudes or powers. Returns four ``None`` values if the signal's
            output data is not valid.
        """
        plot_ampl = self.plot_ampl
        use_amplitude = plot_ampl if amplitude is None else amplitude

        output = self.get_selected_signal().output_data
        if not output.is_valid():
            return None, None, None, None

        if use_amplitude:
            return output.times, output.freq, output.ampl, output.avg_ampl
        return output.times, output.freq, output.powers, output.avg_pow

    def set_plot_type(self, amplitude_selected=True):
        """
        Chooses whether the main and amplitude plots show amplitude or power.

        If a result is already plotted, it is redrawn with the new setting.

        :param amplitude_selected: True for amplitude, False for power
        """
        self.plot_ampl = amplitude_selected
        if not self.is_plotted:
            return

        times, freq, values, avg_values = self.get_values_to_plot()
        self.plot(times, freq, values, avg_values)

    def plot(self, times, freq, values, avg_values):
        """
        Draws ``values`` over time/frequency in the main plot and
        ``avg_values`` against frequency in the amplitude plot.
        """
        main = self.view.main_plot()
        main.plot(times, values, freq)

        side = self.view.amplitude_plot()
        side.plot(avg_values, freq)

    def plot_output(self):
        """
        Plots the WT/WFT result of the selected signal.

        If there is no result to show, both plots are cleared instead.
        """
        times, freq, values, avg_values = self.get_values_to_plot()

        if times is None or freq is None:
            self.view.main_plot().clear()
            self.view.amplitude_plot().clear()
            return

        self.plot(times, freq, values, avg_values)
        self.is_plotted = True

    def get_params(self, all_signals=True) -> TFParams:
        """
        Builds the transform parameters from the current GUI state.

        As a side effect, ``self.signals_calc`` is set to the signals that the
        calculation will cover.

        :param all_signals: whether to include every signal, or only the
            selected one
        """
        self.signals_calc = (
            self.signals if all_signals
            else self.signals.only(self.selected_signal_name)
        )

        view = self.view
        return create(
            params_type=TFParams,
            signals=self.signals_calc,
            fmin=view.get_fmin(),
            fmax=view.get_fmax(),
            f0=view.get_f0(),
            # Window (WFT) and wavelet (WT) come from the same GUI selector;
            # only the one matching the chosen transform is used.
            window=view.get_wt_wft_type(),
            wavelet=view.get_wt_wft_type(),
            cut_edges=view.get_cut_edges(),
            preprocess=view.get_preprocess(),
            transform=view.get_transform_type(),
        )

    def load_data(self):
        """
        Reads the signals from ``self.open_file``.

        If the file does not provide a sampling frequency, the user is asked
        for one.
        """
        self.signals = Signals.from_file(self.open_file)

        if self.signals.has_frequency():
            return

        # NOTE(review): an empty/None frequency is silently ignored here, and
        # on_data_loaded() is not called when the file already has a
        # frequency. Confirm both are intended.
        freq = FrequencyDialog().run_and_get()
        if freq:
            self.set_frequency(freq)
            self.on_data_loaded()

    def plot_signal(self):
        """Draws the currently selected signal on the signal plot."""
        selected = self.get_selected_signal()
        self.view.plot_signal(selected)

    def on_data_loaded(self):
        """
        Runs after the time series are loaded: lists their names and plots the
        selected signal.
        """
        names = self.signals.names()
        self.view.update_signal_listview(names)
        self.plot_signal()

    def on_signal_zoomed(self, rect: Rect) -> None:
        """
        Handles zooming of the signal plot.

        After the base-class handling, the preprocessed signal is re-plotted
        for the zoomed x-range rather than for the whole signal.

        :param rect: the rectangle that was zoomed to
        """
        super().on_signal_zoomed(rect)
        self.plot_preprocessed_signal()

    def on_signal_selected(self, item: Union[QListWidgetItem, str]):
        """
        Handles a selection in the signal list (QListWidget).

        The signals are always reset. If a different signal was chosen, this
        method:

        - plots it in the signal plot and refreshes the x-limits;
        - plots its transform output (or clears the main plots if there is
          none);
        - plots its preprocessed version.

        :param item: the selected list item, or the signal name itself
        """
        name = item.text() if isinstance(item, QListWidgetItem) else item

        self.signals.reset()
        if name == self.selected_signal_name:
            return

        print(f"Selected signal: '{name}'")
        self.selected_signal_name = name
        self.plot_signal()
        self.view.on_xlim_edited()
        self.plot_output()

        self.plot_preprocessed_signal()