Пример #1
0
    def _point_plots_default(self):
        """
        Build a vertical stack of per-beat point plots (pep, x_times, lvet).

        The plots are only built lazily, when a UI first asks for them,
        rather than in ``__init__``.  The shared data source is stored on
        ``self.point_plotdata``.

        Returns
        -------
        VPlotContainer
            One line plot per signal, each drawn against ``peak_times``.
        """
        physio = self.physiodata
        self.point_plotdata = ArrayPlotData(
            peak_times=physio.peak_times.flatten(),
            # X-point indices shifted so that they are relative to the peak
            x_times=physio.x_indices - physio.dzdt_pre_peak,
            lvet=physio.lvet,
            pep=physio.pep)

        stack = VPlotContainer(resizable="hv",
                               bgcolor="lightgray",
                               fill_padding=True,
                               padding=10)

        def _make_subplot(signal_name):
            # One titled line plot of ``signal_name`` versus peak time.
            subplot = Plot(self.point_plotdata)
            subplot.plot(("peak_times", signal_name), line_width=2)
            subplot.title = signal_name
            return subplot

        for signal_name in ("pep", "x_times", "lvet"):
            stack.add(_make_subplot(signal_name))
        stack.padding_top = 10

        return stack
Пример #2
0
    def _karcher_plot_default(self):
        """
        Build the Karcher-mean plot lazily, when a UI first requests it.

        Side effects: sets ``self.interactive`` to True (later computations
        check this flag before pushing results into the plots) and stores the
        data source on ``self.karcher_plotdata``.

        Returns
        -------
        Plot
            The unregistered ensemble mean (thin light-blue line) overlaid
            with the Karcher mean (thick blue line).
        """
        self.interactive = True
        unreg_mean = self.global_ensemble.dzdt_signal
        # Until a real Karcher mean exists, show the unregistered mean
        # restricted to the registration window as a placeholder.
        if self.karcher_mean_calculated:
            karcher_mean = self.dzdt_karcher_mean
        else:
            karcher_mean = unreg_mean[self.dzdt_mask]

        self.karcher_plotdata = ArrayPlotData(
            time=self.global_ensemble.dzdt_time,
            unregistered_mean=unreg_mean,
            karcher_mean=karcher_mean,
            karcher_time=self.dzdt_karcher_mean_time)
        karcher_plot = Plot(self.karcher_plotdata)
        karcher_plot.plot(("time", "unregistered_mean"),
                          line_width=1,
                          color="lightblue")
        karcher_plot.plot(("karcher_time", "karcher_mean"),
                          line_width=3,
                          color="blue")
        # Tell Chaco the Karcher time samples are in ascending order.
        karcher_plot.datasources['karcher_time'].sort_order = "ascending"

        return karcher_plot
Пример #3
0
 def _z0_plot_default(self):
     """
     Build the z0 plot.

     Draws ``raw_z0_signal`` (blue) and ``resp_corrected_z0`` (green) against
     ``self.time``.  The data source is kept on ``self.z0_plot_data``.
     """
     self.z0_plot_data = ArrayPlotData(time=self.time,
                                       raw_data=self.raw_z0_signal,
                                       cleaned_data=self.resp_corrected_z0)
     z0_plot = Plot(self.z0_plot_data)
     for series, colour in (("raw_data", "blue"), ("cleaned_data", "green")):
         z0_plot.plot(("time", series), color=colour, line_width=1)
     # Title is drawn vertically along the right-hand edge.
     z0_plot.title = "z0"
     z0_plot.title_position = "right"
     z0_plot.title_angle = 270
     z0_plot.padding = 20
     return z0_plot
Пример #4
0
 def _image_plot_default(self):
     """
     Build the ECG image plot with a range-selection overlay.

     ``self.ecg_matrix`` is drawn as an image with the ``jet`` colormap.  When
     the matrix is empty, a 100x100 block of zeros is shown instead so the
     plot can still be constructed.

     Side effects: stores the data source on ``self.image_plot_data`` and the
     selector on ``self.image_selection_tool``.

     Returns
     -------
     Plot
         The image plot with the selector attached as an overlay.
     """
     img = self.ecg_matrix
     if img.size == 0:
         # Placeholder image so the plot has something to render.
         img = np.zeros((100, 100))
     self.image_plot_data = ArrayPlotData(imagedata=img)
     plot = Plot(self.image_plot_data)
     self.image_selection_tool = ImageRangeSelector(component=plot,
                                                    tool_mode="range",
                                                    axis="value",
                                                    always_on=True)
     plot.img_plot("imagedata",
                   colormap=jet,
                   name="plot1",
                   origin="bottom left")
     plot.overlays.append(self.image_selection_tool)
     return plot
Пример #5
0
 def _dzdt_plot_default(self):
     """
     Build the dZ/dt plot.

     Draws ``raw_dzdt_signal`` (blue) and ``resp_corrected_dzdt`` (green,
     presumably the respiration-corrected version) against ``self.time``,
     with a vertical "dZ/dt" title on the right-hand side.  The data source
     is stored on ``self.dzdt_plot_data``.

     Returns
     -------
     Plot
         The configured dZ/dt plot.
     """
     # Create plotting components
     self.dzdt_plot_data = ArrayPlotData(
         time=self.time,
         raw_data=self.raw_dzdt_signal,
         cleaned_data=self.resp_corrected_dzdt)
     plot = Plot(self.dzdt_plot_data)
     plot.plot(("time", "raw_data"), color="blue", line_width=1)
     plot.plot(("time", "cleaned_data"), color="green", line_width=1)
     plot.title = "dZ/dt"
     plot.title_position = "right"
     plot.title_angle = 270
     plot.padding = 20
     return plot
Пример #6
0
    def _registration_plot_default(self):
        """
        Build the registration view lazily, when a UI first requests it.

        Left panel (``self.single_plot``): the Karcher mean (blue) and the
        most recently registered function (maroon).  Both start out as the
        same array.

        Right panel (``self.all_plot``): an image of every function.  The
        registered functions are used if registration has already run;
        otherwise the unregistered inputs are used.

        Returns
        -------
        HPlotContainer
            The two panels side by side.
        """
        # Read unconditionally (as before) so the global ensemble is always
        # created here, even when a Karcher mean already exists.
        unregistered = self.global_ensemble.dzdt_signal
        karcher_mean = (self.dzdt_karcher_mean
                        if self.karcher_mean_calculated
                        else unregistered[self.dzdt_mask])

        self.single_registration_plotdata = ArrayPlotData(
            karcher_time=self.dzdt_karcher_mean_time,
            karcher_mean=karcher_mean,
            registered_func=karcher_mean)
        self.single_plot = Plot(self.single_registration_plotdata)
        for series, colour in (("karcher_mean", "blue"),
                               ("registered_func", "maroon")):
            self.single_plot.plot(("karcher_time", series),
                                  line_width=2,
                                  color=colour)

        already_registered = (self.all_beats_registered_to_initial
                              or self.all_beats_registered_to_mode)
        source = (self.registered_functions if already_registered
                  else self.dzdt_functions_to_warp)
        self.all_registration_plotdata = ArrayPlotData(
            image_data=source.copy())
        self.all_plot = Plot(self.all_registration_plotdata)
        self.all_plot.img_plot("image_data", colormap=jet)

        return HPlotContainer(self.single_plot, self.all_plot)
Пример #7
0
    def _plot_default(self):
        """
        Build the time-series plot and attach any stored censor regions.

        The line renderer is kept on ``self.renderer`` so the censor regions
        can be drawn against it.

        Returns
        -------
        Plot
            The configured time-series plot.
        """
        series_data = ArrayPlotData(time=self.time, data=self.data)
        ts_plot = Plot(series_data)
        self.renderer = ts_plot.plot(("time", "data"),
                                     color=self.line_color)[0]
        # Vertical title on the right-hand edge.
        ts_plot.title = self.name
        ts_plot.title_position = "right"
        ts_plot.title_angle = 270
        ts_plot.line_width = 1
        ts_plot.padding = 25
        ts_plot.width = 400

        # Re-attach every stored censor region to the new renderer.
        for region in self.censored_regions:
            region.plot = self.renderer
            # Bare attribute access kept on purpose. It presumably forces a
            # lazily-created ``viz`` to be built; verify before removing.
            region.viz
            region.set_limits(region.start_time, region.end_time)
        return ts_plot
Пример #8
0
    def _plot_default(self):
        """
        Creates a plot of the ecg_ts data and the signals derived during
        the Pan-Tompkins algorithm.

        Draws the auxiliary signal (purple), the QRS power signal (blue) and
        the raw ECG (red).  It also adds:

        * an overlay for each censored interval,
        * a horizontal threshold line,
        * diamond markers for the non-ensembleable peaks,
        * an editable scatter of the ensembleable peaks.

        Side effects: stores the data source and every renderer/tool on
        ``self``.  It also registers ``self._change_peaks`` as a handler for
        ``peak_editor.done_selecting``.

        Returns
        -------
        Plot
            The assembled plot.
        """
        # Create plotting components
        self.plot_data = ArrayPlotData(
            time=self.ecg_time,
            raw_ecg=self.ecg_signal,
            qrs_power=self.qrs_power_signal,
            aux_signal=self.aux_signal,
            # Ensembleable peaks
            peak_times=self.peak_times,
            peak_values=self.peak_values,
            # Non-ensembleable, useful for HR, HRV
            dne_peak_times=self.dne_peak_times,
            dne_peak_values=self.dne_peak_values,
            # Threshold slider bar: spans the first to the last ECG time.
            # It starts at a value of 0 (presumably updated elsewhere).
            thr_times=np.array([self.ecg_time[0], self.ecg_time[-1]]),
            thr_vals=np.array([0, 0]))

        plot = Plot(self.plot_data, use_backbuffer=True)
        self.aux_trace = plot.plot(("time", "aux_signal"),
                                   color="purple",
                                   alpha=0.6,
                                   line_width=2)[0]
        self.qrs_power_trace = plot.plot(("time", "qrs_power"),
                                         color="blue",
                                         alpha=0.8)[0]
        self.ecg_trace = plot.plot(("time", "raw_ecg"),
                                   color="red",
                                   line_width=2)[0]

        # Load the censor regions and add them to the plot. Each interval is
        # stored as (start, end) in the index metadata under "censorNNN".
        # NOTE(review): the overlay's component is qrs_power_trace, but it is
        # attached to ecg_trace.  Presumably this is harmless because both
        # renderers share the "time" index; confirm.
        for n, (start, end) in enumerate(self.censored_intervals):
            censor_key = "censor%03d" % n
            self.censored_region_graphics.append(
                StaticCensoredInterval(component=self.qrs_power_trace,
                                       metadata_name=censor_key))
            self.ecg_trace.overlays.append(self.censored_region_graphics[-1])
            self.ecg_trace.index.metadata[censor_key] = start, end

        # Horizontal line showing the detection threshold (thr_times/thr_vals)
        self.threshold_trace = plot.plot(("thr_times", "thr_vals"),
                                         color="black",
                                         line_width=3)[0]

        # Plot for plausible peaks
        self.ppeak_vis = plot.plot(("dne_peak_times", "dne_peak_values"),
                                   type="scatter",
                                   marker="diamond")[0]

        # Make a scatter plot where you can edit the peaks
        self.peak_vis = plot.plot(("peak_times", "peak_values"),
                                  type="scatter")[0]
        # Last component added to the plot.  This is presumably the peak
        # scatter created just above (the same object as self.peak_vis).
        # Depends on the order of the plot() calls; verify before reordering.
        self.scatter = plot.components[-1]
        self.peak_editor = PeakPickingTool(self.scatter)
        self.scatter.tools.append(self.peak_editor)
        # When the user adds or removes a point, call self._change_peaks
        # (presumably to re-extract the edited peak data).
        self.on_trait_event(self._change_peaks, "peak_editor.done_selecting")

        self.scatter.overlays.append(
            PeakPickingOverlay(component=self.scatter))

        return plot
Пример #9
0
 def _resp_plot_default(self):
     """
     Build the respiration plot.

     Draws ``resp_signal`` (blue) and ``lpfilt_resp`` (green) against
     ``self.time``.  It also adds scatter markers for the inhale-begin
     points (squares) and the exhale-begin points (circles).  The data source
     is stored on ``self.resp_plot_data``.

     Returns
     -------
     Plot
         The configured respiration plot.
     """
     self.resp_plot_data = ArrayPlotData(
         time=self.time,
         inhale_times=self.resp_inhale_begin_times,
         inhale_values=self.resp_inhale_begin_values,
         exhale_times=self.resp_exhale_begin_times,
         exhale_values=self.resp_exhale_begin_values,
         filtered_resp=self.resp_signal,
         lpfilt_resp=self.lpfilt_resp)
     resp_plot = Plot(self.resp_plot_data)

     for trace, colour in (("filtered_resp", "blue"),
                           ("lpfilt_resp", "green")):
         resp_plot.plot(("time", trace), color=colour, line_width=1)

     # Breath landmarks: inhale onsets as squares, exhale onsets as circles.
     landmarks = ((("inhale_times", "inhale_values"), "square"),
                  (("exhale_times", "exhale_values"), "circle"))
     for columns, marker_shape in landmarks:
         resp_plot.plot(columns, type="scatter", marker=marker_shape)

     resp_plot.title = "Respiration"
     resp_plot.title_position = "right"
     resp_plot.title_angle = 270
     resp_plot.padding = 20
     return resp_plot
Пример #10
0
class GroupRegisterDZDT(HasTraits):
    """
    Interactive group-wise SRVF registration of per-beat dZ/dt waveforms.

    The view walks through three steps:

    1. ``calculate_karcher_mean`` -- estimate an initial Karcher mean from a
       random subset of beats (``dzdt_karcher_mean_inputs``).
    2. ``align_all_beats_to_initial`` -- warp every beat onto that mean.
    3. ``detect_modes`` -- cluster the warping functions into ``n_modes``
       groups and register each group to its own Karcher mean.

    Most parameters and results are delegated to ``physiodata``.
    """
    physiodata = Instance(PhysioData)

    global_ensemble = Instance(GlobalEnsembleAveragedHeartBeat)
    srvf_lambda = DelegatesTo("physiodata")
    srvf_max_karcher_iterations = DelegatesTo("physiodata")
    srvf_update_min = DelegatesTo("physiodata")
    srvf_karcher_mean_subset_size = DelegatesTo("physiodata")
    dzdt_srvf_karcher_mean = DelegatesTo("physiodata")
    dzdt_karcher_mean = DelegatesTo("physiodata")
    dzdt_warping_functions = DelegatesTo("physiodata")
    srvf_use_moving_ensembled = DelegatesTo("physiodata")
    bspline_before_warping = DelegatesTo("physiodata")
    dzdt_functions_to_warp = DelegatesTo("physiodata")

    # Configures the slice of time to be registered
    srvf_t_min = DelegatesTo("physiodata")
    srvf_t_max = DelegatesTo("physiodata")
    dzdt_karcher_mean_time = DelegatesTo("physiodata")
    dzdt_mask = Array

    # Holds indices of beats used to calculate initial Karcher mean
    dzdt_karcher_mean_inputs = DelegatesTo("physiodata")
    dzdt_karcher_mean_over_iterations = DelegatesTo("physiodata")
    dzdt_num_inputs_to_group_warping = DelegatesTo("physiodata")
    srvf_iteration_distances = DelegatesTo("physiodata")
    srvf_iteration_energy = DelegatesTo("physiodata")

    # For calculating multiple modes
    all_beats_registered_to_initial = Bool(False)
    all_beats_registered_to_mode = Bool(False)
    n_modes = DelegatesTo("physiodata")
    max_kmeans_iterations = DelegatesTo("physiodata")
    mode_dzdt_karcher_means = DelegatesTo("physiodata")
    mode_cluster_assignment = DelegatesTo("physiodata")
    mode_dzdt_srvf_karcher_means = DelegatesTo("physiodata")

    # graphics items
    karcher_plot = Instance(Plot, transient=True)
    registration_plot = Instance(HPlotContainer, transient=True)
    karcher_plotdata = Instance(ArrayPlotData, transient=True)

    # Buttons
    b_calculate_karcher_mean = Button(label="Calculate Karcher Mean")
    b_align_all_beats = Button(label="Warp all")
    b_find_modes = Button(label="Discover Modes")
    b_edit_modes = Button(label="Score Modes")
    # Set to True once the Karcher plot has been built; computations check
    # it before pushing results into the plots.
    interactive = Bool(False)

    # State flags referenced by ``enabled_when`` in the view below.  They are
    # declared as traits instead of being created as plain attributes on
    # first assignment.  That way the view is notified when they change, and
    # "dirty" exists before any parameter has been edited.
    karcher_mean_calculated = Bool(False)
    dirty = Bool(True)

    # Holds the karcher modes
    mode_beat_train = Instance(ModeKarcherBeatTrain)
    edit_listening = Bool(False,
                          desc="If true, update_plots is called"
                          " when a beat gets hand labeled")

    def __init__(self, **traits):
        super(GroupRegisterDZDT, self).__init__(**traits)

        logger.info("Initializing dZdt registration")

        # Is there already a Karcher mean in the physiodata?
        self.karcher_mean_calculated = (self.dzdt_srvf_karcher_mean.size > 0
                                        and self.dzdt_karcher_mean.size > 0)

        # Build the time mask and the functions to warp before any SRVF work
        self._update_original_functions()

        self.n_functions = self.dzdt_functions_to_warp.shape[0]
        self.n_samples = self.dzdt_functions_to_warp.shape[1]
        self._init_warps()

        self.select_new_samples()

    def _mode_beat_train_default(self):
        logger.info("creating default mea_beat_train")
        assert self.physiodata is not None
        mkbt = ModeKarcherBeatTrain(physiodata=self.physiodata)
        return mkbt

    def _init_warps(self):
        """ For loading warps from mea.mat """
        # Stored warps are only usable if there is one per function to warp
        if self.dzdt_warping_functions.shape[0] == \
                self.dzdt_functions_to_warp.shape[0]:
            self.all_beats_registered_to_mode = True
            self._forward_warp_beats()

    def _global_ensemble_default(self):
        return GlobalEnsembleAveragedHeartBeat(physiodata=self.physiodata)

    def _karcher_plot_default(self):
        """
        Build the Karcher-mean plot lazily, when a UI first requests it.

        Sets ``self.interactive`` so later computations update the plots.
        """
        self.interactive = True
        unreg_mean = self.global_ensemble.dzdt_signal
        # Until a real Karcher mean exists, show the unregistered mean
        # restricted to the registration window as a placeholder.
        if self.karcher_mean_calculated:
            karcher_mean = self.dzdt_karcher_mean
        else:
            karcher_mean = unreg_mean[self.dzdt_mask]

        self.karcher_plotdata = ArrayPlotData(
            time=self.global_ensemble.dzdt_time,
            unregistered_mean=unreg_mean,
            karcher_mean=karcher_mean,
            karcher_time=self.dzdt_karcher_mean_time)
        karcher_plot = Plot(self.karcher_plotdata)
        karcher_plot.plot(("time", "unregistered_mean"),
                          line_width=1,
                          color="lightblue")
        karcher_plot.plot(("karcher_time", "karcher_mean"),
                          line_width=3,
                          color="blue")
        # Tell Chaco the Karcher time samples are in ascending order.
        karcher_plot.datasources['karcher_time'].sort_order = "ascending"

        return karcher_plot

    def _registration_plot_default(self):
        """
        Build the registration view lazily, when a UI first requests it.

        Left: Karcher mean (blue) plus the most recently registered function
        (maroon).  Right: an image of all functions, registered if available.
        """
        unreg_mean = self.global_ensemble.dzdt_signal
        # Temporarily fill in the karcher mean
        if not self.karcher_mean_calculated:
            karcher_mean = unreg_mean[self.dzdt_mask]
        else:
            karcher_mean = self.dzdt_karcher_mean

        self.single_registration_plotdata = ArrayPlotData(
            karcher_time=self.dzdt_karcher_mean_time,
            karcher_mean=karcher_mean,
            registered_func=karcher_mean)

        self.single_plot = Plot(self.single_registration_plotdata)
        self.single_plot.plot(("karcher_time", "karcher_mean"),
                              line_width=2,
                              color="blue")
        self.single_plot.plot(("karcher_time", "registered_func"),
                              line_width=2,
                              color="maroon")

        if self.all_beats_registered_to_initial or \
           self.all_beats_registered_to_mode:
            image_data = self.registered_functions.copy()
        else:
            image_data = self.dzdt_functions_to_warp.copy()

        # Create a plot of all the functions registered or not
        self.all_registration_plotdata = ArrayPlotData(image_data=image_data)
        self.all_plot = Plot(self.all_registration_plotdata)
        self.all_plot.img_plot("image_data", colormap=jet)

        return HPlotContainer(self.single_plot, self.all_plot)

    def _forward_warp_beats(self):
        """
        Rebuild ``registered_functions`` from the stored warping functions.

        Registered beats are not stored, so they are recomputed by applying
        ``dzdt_warping_functions`` to ``dzdt_functions_to_warp``.
        """
        logger.info("Applying warps to functions for plotting")
        # Re-create gam: undo the [0, 1] -> [srvf_t_min, srvf_t_max] scaling
        # applied when the warps were saved.
        gam = (self.dzdt_warping_functions - self.srvf_t_min) / (
            self.srvf_t_max - self.srvf_t_min)

        aligned_functions = np.zeros_like(self.dzdt_functions_to_warp)
        t = self.dzdt_karcher_mean_time
        for k in range(self.n_functions):
            aligned_functions[k] = np.interp((t[-1] - t[0]) * gam[k] + t[0], t,
                                             self.dzdt_functions_to_warp[k])

        self.registered_functions = aligned_functions

    @on_trait_change("dzdt_num_inputs_to_group_warping")
    def select_new_samples(self):
        """Randomly pick (without replacement) the beats for the template."""
        nbeats = self.dzdt_functions_to_warp.shape[0]
        nsamps = min(self.dzdt_num_inputs_to_group_warping, nbeats)
        self.dzdt_karcher_mean_inputs = np.random.choice(nbeats,
                                                         size=nsamps,
                                                         replace=False)

    @on_trait_change(
        ("physiodata.srvf_lambda, "
         "physiodata.srvf_update_min, physiodata.srvf_use_moving_ensembled, "
         "physiodata.srvf_karcher_mean_subset_size, "
         "physiodata.bspline_before_warping"))
    def params_edited(self):
        """Invalidate all registration results after a parameter change."""
        self.dirty = True
        self.karcher_mean_calculated = False
        self.all_beats_registered_to_initial = False
        self.all_beats_registered_to_mode = False
        self._update_original_functions()

    def _update_original_functions(self):
        """
        Recompute the registration window and the functions to be warped.

        Sets three things:

        * ``dzdt_mask``: samples whose time relative to the R peak lies in
          ``[srvf_t_min, srvf_t_max]``.
        * ``dzdt_karcher_mean_time``: the times inside that window.
        * ``dzdt_functions_to_warp``: the windowed raw or moving-ensembled
          dZ/dt, optionally B-spline smoothed.

        When interactive, the image of all functions is refreshed.
        """
        logger.info("updating time slice and functions to register")

        # Get the time relative to R
        dzdt_time = (np.arange(self.physiodata.dzdt_matrix.shape[1],
                               dtype=float)
                     - self.physiodata.dzdt_pre_peak)
        self.dzdt_mask = ((dzdt_time >= self.srvf_t_min)
                          * (dzdt_time <= self.srvf_t_max))
        self.dzdt_karcher_mean_time = dzdt_time[self.dzdt_mask]

        # Extract corresponding data
        if self.srvf_use_moving_ensembled:
            source_matrix = self.physiodata.mea_dzdt_matrix
        else:
            source_matrix = self.physiodata.dzdt_matrix
        self.dzdt_functions_to_warp = source_matrix[:, self.dzdt_mask]

        if self.bspline_before_warping:
            logger.info("Smoothing inputs with B-Splines")
            t = self.dzdt_karcher_mean_time
            self.dzdt_functions_to_warp = np.vstack(
                [UnivariateSpline(t, func, s=0.05)(t)
                 for func in self.dzdt_functions_to_warp])
        if self.interactive:
            self.all_registration_plotdata.set_data(
                "image_data", self.dzdt_functions_to_warp.copy())
            self.all_plot.request_redraw()

    def _b_calculate_karcher_mean_fired(self):
        self.calculate_karcher_mean()

    def calculate_karcher_mean(self):
        """
        Calculates an initial Karcher Mean.

        Uses the beats selected in ``dzdt_karcher_mean_inputs``.  Stores the
        function-space and SRVF-space means and keeps the registration
        problem on ``self.rp``.
        """
        reg_problem = RegistrationProblem(
            self.dzdt_functions_to_warp[self.dzdt_karcher_mean_inputs].T,
            sample_times=self.dzdt_karcher_mean_time,
            max_karcher_iterations=self.srvf_max_karcher_iterations,
            lambda_value=self.srvf_lambda,
            update_min=self.srvf_update_min)
        reg_problem.run_registration_parallel()
        self.dzdt_karcher_mean = reg_problem.function_karcher_mean
        self.dzdt_srvf_karcher_mean = reg_problem.srvf_karcher_mean
        self.karcher_mean_calculated = True

        # Update plots if this is interactive
        if self.interactive:
            self.karcher_plotdata.set_data("karcher_mean",
                                           self.dzdt_karcher_mean)
            self.karcher_plot.request_redraw()
            self.single_registration_plotdata.set_data("karcher_mean",
                                                       self.dzdt_karcher_mean)
            self.single_plot.request_redraw()
        self.rp = reg_problem

    def _b_align_all_beats_fired(self):
        self.align_all_beats_to_initial()

    def _b_find_modes_fired(self):
        self.detect_modes()

    def detect_modes(self):
        """
        Uses the SRD-based clustering method described in Kurtek 2017.

        1. Convert each beat's warping function into a square-root density
           (SRD).  Cut a complete-linkage tree built on pairwise Fisher-Rao
           distances into ``n_modes`` clusters.
        2. Refine the clusters k-means style, for at most
           ``max_kmeans_iterations`` passes, until the assignments stop
           changing.
        3. Register the dZ/dt functions of each final cluster to that
           cluster's Karcher mean.

        Writes ``dzdt_warping_functions``, ``registered_functions``,
        ``mode_dzdt_karcher_means`` and ``mode_cluster_assignment``.
        """
        if not self.karcher_mean_calculated:
            fail("Must calculate an initial Karcher mean first")
            return
        if not self.all_beats_registered_to_initial:
            fail("All beats must be registered to the initial Karcher mean")
            return

        # Calculate the SRDs from min/max-normalised warping functions
        warps = self.dzdt_warping_functions.T
        wmax = warps.max()
        wmin = warps.min()
        warps = (warps - wmin) / (wmax - wmin)
        densities = np.diff(warps, axis=0)
        SRDs = np.sqrt(densities)
        # pairwise distances
        srd_pairwise = pairwise_distances(SRDs.T, metric=fisher_rao_dist)
        tri = srd_pairwise[np.triu_indices_from(srd_pairwise, 1)]
        linkage = complete(tri)

        # Performs an iteration of k-means
        def cluster_karcher_means(initial_assignments):
            cluster_means = {}
            cluster_ids = np.unique(initial_assignments).tolist()
            warping_functions = np.zeros_like(SRDs)

            # Calculate a Karcher mean for each cluster
            for cluster_id in cluster_ids:
                logger.info("Cluster ID: %s", cluster_id)
                cluster_id_mask = initial_assignments == cluster_id
                cluster_srds = SRDs[:, cluster_id_mask]

                # If there is only a single SRD in this cluster, it is the mean
                if cluster_id_mask.sum() == 1:
                    cluster_means[cluster_id] = cluster_srds
                    continue

                # Run group registration to get Karcher mean
                cluster_reg = RegistrationProblem(
                    cluster_srds,
                    sample_times=np.arange(SRDs.shape[0], dtype=float),
                    max_karcher_iterations=self.srvf_max_karcher_iterations,
                    lambda_value=self.srvf_lambda,
                    update_min=self.srvf_update_min)
                cluster_reg.run_registration_parallel()
                cluster_means[cluster_id] = cluster_reg.function_karcher_mean
                warping_functions[:, cluster_id_mask] = \
                    cluster_reg.mean_to_orig_warps

            # Scale the cluster Karcher means so the FR distance works
            scaled_cluster_means = {}
            for k, v in cluster_means.items():
                scaled_cluster_means[k] = rescale_cluster_mean(v)

            # There are now k cluster means, which is closest for each SRD?
            # Also save its distance to its cluster's Karcher mean
            srd_cluster_assignments, srd_cluster_distances = get_closest_mean(
                SRDs, scaled_cluster_means)
            return (srd_cluster_assignments, srd_cluster_distances,
                    scaled_cluster_means, warping_functions)

        # Iterate until assignments stabilize
        last_assignments = fcluster(linkage,
                                    self.n_modes,
                                    criterion="maxclust")
        # Used as-is if no refinement pass runs (max_kmeans_iterations == 0)
        assignments = last_assignments
        stabilized = False
        n_iters = 0
        while not stabilized and n_iters < self.max_kmeans_iterations:
            logger.info("Iteration %d", n_iters)
            assignments, _, _, _ = cluster_karcher_means(last_assignments)
            stabilized = np.all(last_assignments == assignments)
            last_assignments = assignments.copy()
            n_iters += 1

        # Finalize the clusters by aligning all the functions to the cluster mean
        t = self.dzdt_karcher_mean_time
        # Identity warp in the normalised [0, 1] convention used for the
        # warping functions below (and by align_all_beats_to_initial)
        identity_warp = (t - t[0]) / (t[-1] - t[0])
        cluster_means = {}
        cluster_ids = np.unique(assignments)
        warping_functions = np.zeros_like(self.dzdt_functions_to_warp.T)
        self.registered_functions = np.zeros_like(self.dzdt_functions_to_warp)
        # Calculate a Karcher mean for each cluster
        for cluster_id in cluster_ids:
            cluster_id_mask = assignments == cluster_id
            cluster_funcs = self.dzdt_functions_to_warp.T[:, cluster_id_mask]

            # A single function is its own Karcher mean: give it the identity
            # warp and keep it unchanged.  Otherwise it would keep the
            # all-zero placeholder warp and registered function.
            if cluster_id_mask.sum() == 1:
                cluster_means[cluster_id] = cluster_funcs
                warping_functions[:, cluster_id_mask] = \
                    identity_warp[:, np.newaxis]
                self.registered_functions[cluster_id_mask] = cluster_funcs.T
                continue

            # Run group registration to get Karcher mean
            cluster_reg = RegistrationProblem(
                cluster_funcs,
                sample_times=self.dzdt_karcher_mean_time,
                max_karcher_iterations=self.srvf_max_karcher_iterations,
                lambda_value=self.srvf_lambda,
                update_min=self.srvf_update_min)
            cluster_reg.run_registration_parallel()
            cluster_means[cluster_id] = cluster_reg.function_karcher_mean
            warping_functions[:, cluster_id_mask] = \
                cluster_reg.mean_to_orig_warps
            self.registered_functions[cluster_id_mask] = \
                cluster_reg.registered_functions.T

        # Save the warps to the modes as the final warping functions
        self.dzdt_warping_functions = warping_functions.T \
            * (self.srvf_t_max - self.srvf_t_min) \
            + self.srvf_t_min

        # re-order the means into consecutive ids 0..n-1
        cluster_ids = sorted(cluster_means.keys())
        final_assignments = np.zeros_like(assignments)
        final_modes = []
        for final_id, orig_id in enumerate(cluster_ids):
            final_assignments[assignments == orig_id] = final_id
            final_modes.append(cluster_means[orig_id].squeeze())

        self.mode_dzdt_karcher_means = np.vstack(final_modes)
        self.mode_cluster_assignment = final_assignments
        self.all_beats_registered_to_mode = True

    def align_all_beats_to_initial(self):
        """
        Warp every function onto the initial Karcher mean.

        For each function, ``dp`` aligns its L2-normalised SRVF to the
        L2-normalised SRVF Karcher mean.  The warp is rescaled to [0, 1] and
        used to resample the function.  Results are stored in two places:

        * ``registered_functions``
        * ``dzdt_warping_functions``, rescaled to
          ``[srvf_t_min, srvf_t_max]``

        When interactive, the plots and a progress dialog are updated for
        each beat.
        """
        if not self.karcher_mean_calculated:
            logger.warning("Calculate Karcher mean first")
            return
        logger.info("Aligning all beats to the Karcher Mean")
        if self.interactive:
            progress = ProgressDialog(
                title="ICG Warping",
                min=0,
                max=self.n_functions,
                show_time=True,
                message="Warping dZ/dt to Karcher Mean...")
            progress.open()

        template_func = self.dzdt_srvf_karcher_mean
        normed_template_func = template_func / np.linalg.norm(template_func)
        # SRVF of each function: f' / sqrt(|f'| + eps)
        fy, fx = np.gradient(self.dzdt_functions_to_warp.T, 1, 1)
        srvf_functions = (fy / np.sqrt(np.abs(fy) + eps)).T

        gam = np.zeros(self.dzdt_functions_to_warp.shape, dtype=float)
        aligned_functions = self.dzdt_functions_to_warp.copy()

        t = self.dzdt_karcher_mean_time
        for k in range(self.n_functions):
            q_c = srvf_functions[k] / np.linalg.norm(srvf_functions[k])
            G, T = dp(normed_template_func, t, q_c, t, t, t, self.srvf_lambda)
            gam0 = np.interp(t, T, G)
            gam[k] = (gam0 - gam0[0]) / (gam0[-1] - gam0[0])  # change scale
            aligned_functions[k] = np.interp((t[-1] - t[0]) * gam[k] + t[0], t,
                                             self.dzdt_functions_to_warp[k])

            if self.interactive:
                # Update the image plot
                self.all_registration_plotdata.set_data(
                    "image_data", aligned_functions)

                # Update the registration plot
                self.single_registration_plotdata.set_data(
                    "registered_func", aligned_functions[k])
                self.single_plot.request_redraw()

                progress.update(k)

        self.registered_functions = aligned_functions.copy()

        self.dzdt_warping_functions = gam * (
            self.srvf_t_max - self.srvf_t_min) + self.srvf_t_min

        if self.interactive:
            # Complete the progress bar; also valid when there are no beats
            progress.update(self.n_functions)

        self.all_beats_registered_to_initial = True

    def _b_edit_modes_fired(self):
        self.mode_beat_train.edit_traits()

    def _point_plots_default(self):
        """
        Instead of defining these in __init__, only
        construct the plots when a ui is requested
        """

        self.point_plotdata = ArrayPlotData(
            peak_times=self.physiodata.peak_times.flatten(),
            x_times=self.physiodata.x_indices - self.physiodata.dzdt_pre_peak,
            lvet=self.physiodata.lvet,
            pep=self.physiodata.pep)

        container = VPlotContainer(resizable="hv",
                                   bgcolor="lightgray",
                                   fill_padding=True,
                                   padding=10)

        for sig in ("pep", "x_times", "lvet"):
            temp_plt = Plot(self.point_plotdata)
            temp_plt.plot(("peak_times", sig), line_width=2)
            temp_plt.title = sig
            container.add(temp_plt)
        container.padding_top = 10

        return container

    mean_widgets = VGroup(VGroup(Item("karcher_plot",
                                      editor=ComponentEditor(),
                                      show_label=False),
                                 Item("registration_plot",
                                      editor=ComponentEditor(),
                                      show_label=False),
                                 show_labels=False),
                          Item("srvf_use_moving_ensembled",
                               label="Use Moving Ensembled dZ/dt"),
                          Item("bspline_before_warping",
                               label="B Spline smoothing"),
                          Item("srvf_t_min", label="Epoch Start Time"),
                          Item("srvf_t_max", label="Epoch End Time"),
                          Item("srvf_lambda", label="Lambda Value"),
                          Item("dzdt_num_inputs_to_group_warping",
                               label="Template uses N beats"),
                          Item("srvf_max_karcher_iterations",
                               label="Max Karcher Iterations"),
                          HGroup(
                              Item("b_calculate_karcher_mean",
                                   label="Step 1:",
                                   enabled_when="dirty"),
                              Item("b_align_all_beats",
                                   label="Step 2:",
                                   enabled_when="karcher_mean_calculated")),
                          label="Initial Karcher Mean")

    mode_widgets = VGroup(
        Item("n_modes", label="Number of Modes/Clusters"),
        Item("max_kmeans_iterations"),
        Item("b_find_modes",
             show_label=False,
             enabled_when="all_beats_registered_to_initial"),
        Item("b_edit_modes",
             show_label=False,
             enabled_when="all_beats_registered_to_mode"))

    traits_view = MEAPView(HSplit(mean_widgets, mode_widgets),
                           resizable=True,
                           win_title="ICG Warping Tools",
                           width=800,
                           height=700,
                           buttons=[OKButton, CancelButton])
Пример #11
0
    def _plot_default(self):
        """
        Build the per-beat plot of ``self.signal`` with a beat-selection tool.

        The trace underneath (purple) and the trace on top depend on the
        signal:

        * "tpr", "co", "sv": if ``self.plotdata`` holds a non-empty
          ``resp_corrected_<signal>`` array, the uncorrected signal goes
          underneath and the corrected one on top.  Otherwise only the
          uncorrected signal is plotted.
        * "hr": the raw heart rate goes underneath and ``mea_hr`` on top.
        * Anything else: only the signal itself is plotted.

        The top signal is drawn both as a blue line and as a scatter coloured
        by ``beat_type``.  A ``BeatSelection`` tool is attached to that
        scatter.

        Side effects: stores the tool on ``self.selection`` and connects
        ``self._selection_changed`` to its ``selection_completed`` trait.

        Returns
        -------
        Plot
            The configured plot.
        """
        plot = Plot(self.plotdata)

        # Default: plot the signal itself.  Only replaced below when a usable
        # alternative exists.  Previously the corrected signal was chosen and
        # then immediately overwritten, so it was never shown.
        signal = self.signal
        if self.signal in ("tpr", "co", "sv"):
            rc_signal = "resp_corrected_" + self.signal
            if (rc_signal in self.plotdata.arrays
                    and self.plotdata.arrays[rc_signal].size > 0):
                # Plot the resp-uncorrected version underneath
                plot.plot(("peak_times", self.signal),
                          type="line",
                          color="purple")
                signal = rc_signal
        elif self.signal == "hr":
            plot.plot(("peak_times", self.signal), type="line", color="purple")
            signal = "mea_hr"

        # Create the plot
        plot.plot(("peak_times", signal), type="line", color="blue")
        plot.plot(("peak_times", signal, "beat_type"),
                  type="cmap_scatter",
                  color_mapper=jet,
                  name="my_plot",
                  marker="circle",
                  border_visible=False,
                  outline_color="transparent",
                  bg_color="white",
                  index_sort="ascending",
                  marker_size=3,
                  fill_alpha=0.8)

        # Tweak some of the plot properties
        plot.title = self.signal
        plot.title_position = "right"
        plot.title_angle = 270
        plot.line_width = 1
        plot.padding = 20

        # Right now, some of the tools are a little invasive, and we need the
        # actual ScatterPlot object to give to them
        my_plot = plot.plots["my_plot"][0]

        # Attach some tools to the plot
        self.selection = BeatSelection(component=my_plot)
        my_plot.active_tool = self.selection
        selection_overlay = RangeSelectionOverlay(component=my_plot)
        my_plot.overlays.append(selection_overlay)

        # Set up the trait handler for the selection
        self.selection.on_trait_change(self._selection_changed,
                                       'selection_completed')

        return plot