Esempio n. 1
0
class MEAPhysioExperiment(PhysioExperiment):
    """PhysioExperiment variant that extracts moving ensemble averages (MEA).

    Its ``physio_files`` only accept ``PhysioFileMEAAnalysis`` instances and
    ``event_constructor`` returns ``MovingEnsembler``.
    """
    events = List(Instance(ExpEvent))
    physio_files = List(Instance(PhysioFileMEAAnalysis))

    def _b_save_fired(self):
        # Saving from the UI is not supported for MEA experiments.
        raise NotImplementedError()

    def file_constructor(self):
        """Return the class used for entries of ``physio_files``.

        NOTE(review): the original only defined the misspelled
        ``file_constructur``, so a base-class ``file_constructor`` (if any)
        was never overridden even though ``physio_files`` only accepts
        ``PhysioFileMEAAnalysis``. Confirm against PhysioExperiment.
        """
        return PhysioFileMEAAnalysis

    # Backward-compatible alias for the original (misspelled) method name.
    file_constructur = file_constructor

    def event_constructor(self):
        """Return the class used to build experiment events."""
        return MovingEnsembler

    def run(self):
        """Read the spreadsheet, process events once, then extract MEAs."""
        self.open_xls()
        # Events only need to be organized the first time through
        if not self.events_processed:
            self._process_events()
        self.extract_mea()

    def extract_mea(self):
        """Call ``extract_mea()`` on every physio file in the experiment."""
        for pf in self.physio_files:
            pf.extract_mea()

    def save_mea_document(self, output_path=""):
        """Save MEA results, defaulting ``output_path`` to ``output_xls``.

        NOTE(review): unfinished stub -- it only resolves the output path
        and does not write anything yet.
        """
        if not output_path:
            output_path = self.output_xls

    traits_view = MEAPView(HSplit(
        Group(
            Group(Item("input_path", label="Input XLS"),
                  Item("output_directory", label="Output Directory"),
                  Item("b_run", show_label=False),
                  Item("b_save", show_label=False),
                  orientation="horizontal"), ), ),
                           resizable=True,
                           win_title="MEA Analysis")
Esempio n. 2
0
class MEAPGreeter(HasTraits):
    """Launcher window with one button per MEAP tool.

    Each tool's module is imported only when its button is pressed, and the
    tool's own traits window is then opened.
    """
    preproc = Button("Preprocess")
    analyze = Button("Analyze")
    configure = Button("Configure MEAP")
    batch_spreadsheet = Button("Create Batch Spreadsheet")
    register_dZdt = Button("Batch Register dZ/dt")

    def _preproc_fired(self):
        from meap.preprocessing import PreprocessingPipeline
        PreprocessingPipeline().configure_traits()

    def _analyze_fired(self):
        from meap.physio_analysis import PhysioExperiment
        PhysioExperiment().configure_traits()

    def _register_dZdt_fired(self):
        from meap.batch_warp_dzdt import BatchGroupRegisterDZDT
        BatchGroupRegisterDZDT().configure_traits()

    def _batch_spreadsheet_fired(self):
        from meap.make_batch_spreadsheet import BatchFileTool
        BatchFileTool().configure_traits()

    def _configure_fired(self):
        # Placeholder until a configuration dialog exists
        print("Not implemented yet!")

    traits_view = MEAPView(
        VGroup(
            Item("preproc"),
            Item("analyze"),
            Item("batch_spreadsheet"),
            Item("register_dZdt"),
            Item("configure"),
            show_labels=False,
        )
    )
Esempio n. 3
0
class PreprocessingPipeline(HasTraits):
    """Batch of acq files to preprocess, listed in a spreadsheet.

    The first worksheet's first row holds column headers; every following
    row with one cell per header becomes a ``PhysioFilePreprocessor``.
    """
    # Spreadsheet listing the acq files (one per row)
    group_excel_file = File
    # One preprocessor per usable spreadsheet row
    physio_files = List(Instance(PhysioFilePreprocessor))
    # Interactively process the data? (passed to each preprocessor)
    interactive = Bool(True)
    # Column names read from the spreadsheet's first row
    header = List

    @on_trait_change("group_excel_file")
    def open_xls(self, fpath=""):
        """Load ``fpath`` (default: ``group_excel_file``) and append its rows.

        Cell values in the ``in_mri`` column become bools; other cells become
        floats when possible and strings otherwise. Each row must contain a
        ``file`` column; ``outfile`` is optional.
        """
        if not fpath:
            fpath = self.group_excel_file
        if not fpath:
            # Nothing selected (e.g. the trait was cleared): nothing to load.
            return
        logger.info("Loading %s", fpath)
        wb = xlrd.open_workbook(fpath)
        sheet = wb.sheet_by_index(0)
        header = [str(item.value) for item in sheet.row(0)]

        for rnum in range(1, sheet.nrows):
            rw = sheet.row(rnum)
            if len(rw) != len(header):
                # Incomplete rows are skipped rather than treated as errors
                continue
            # convert to numbers and strings
            specs = {}
            for item, hdr in zip(rw, header):
                if hdr == "in_mri":
                    specs[hdr] = bool(item.value)
                    continue
                try:
                    specs[hdr] = float(item.value)
                except (TypeError, ValueError):
                    specs[hdr] = str(item.value)

            self.physio_files.append(
                PhysioFilePreprocessor(interactive=self.interactive,
                                       file=specs['file'],
                                       outfile=specs.get('outfile', ''),
                                       specs=specs))
        self.header = header

    traits_view = MEAPView(VGroup(
        VGroup("group_excel_file", show_border=True, label="Excel Files"),
        HGroup(Item("physio_files", editor=pipeline_table, show_label=False)),
    ),
                           resizable=True,
                           width=300,
                           height=500)
Esempio n. 4
0
class CensorRegion(HasTraits):
    """A time span selected for censoring on a chaco line plot.

    The span is kept in two places: in ``start_time``/``end_time`` and as a
    ``(start, end)`` tuple in ``plot.index.metadata[metadata_name]``. Changes
    to that metadata are copied back into the two float traits.
    """
    start_time = Float(-1.)
    end_time = Float(-1.)
    viz = Instance(RangeSelectionOverlay, transient=True)
    metadata_name = Str
    plot = Instance(LinePlot, transient=True)

    traits_view = MEAPView(Item("start_time"), Item("end_time"))

    def set_limits(self, start, end):
        """Store ``(start, end)`` in the plot index metadata for this region."""
        index_metadata = self.plot.index.metadata
        index_metadata[self.metadata_name] = start, end

    @on_trait_change("plot.index.metadata")
    def metadata_chaned(self):
        # A missing or None entry means there is no selection to mirror.
        limits = self.plot.index.metadata.get(self.metadata_name)
        if limits is None:
            return
        self.start_time, self.end_time = limits

    def _viz_default(self):
        """Install the selection tool and return a range overlay for the plot."""
        target = self.plot
        key = self.metadata_name
        # Ctrl-drag appends a new selection; left button selects
        target.active_tool = CensorSelection(target,
                                             selection_mode="append",
                                             append_key=KeySpec("control"),
                                             left_button_selects=True,
                                             metadata_name=key)
        overlay = RangeSelectionOverlay(component=target, metadata_name=key)
        target.overlays.append(overlay)
        return overlay
Esempio n. 5
0
class PhysioFilePreprocessor(HasTraits):
    """One spreadsheet row's acq file plus the ``MEAPPipeline`` that processes it.

    ``file``, ``outfile`` and ``importer_kwargs`` are delegated to the
    pipeline. ``specs`` holds the raw spreadsheet row.
    """
    # Holds the data from
    physiodata = Instance(PhysioData)
    interactive = Bool(True)
    pipeline = Instance(MEAPPipeline)
    file = DelegatesTo("pipeline")
    outfile = DelegatesTo("pipeline")
    specs = Dict
    importer_kwargs = DelegatesTo("pipeline")
    mea_saved = Property(Bool, depends_on="file,outfile")

    def _outfile_default(self):
        # "x.acq" -> "x.mea"; any other name just gets ".mea" appended
        stem = self.file[:-4] if self.file.endswith(".acq") else self.file
        return stem + ".mea"

    def __init__(self, **traits):
        super(PhysioFilePreprocessor, self).__init__(**traits)

        # Spreadsheet columns whose "subject_"-prefixed name is an editable
        # PhysioData trait are forwarded to the importer.
        editable = PhysioData().editable_traits()
        for column, value in self.specs.items():
            trait_name = "subject_" + column
            if trait_name in editable:
                self.importer_kwargs[trait_name] = value

    def _pipeline_default(self):
        return MEAPPipeline()

    def _get_mea_saved(self):
        # An input that already ends in "mea" needs no saving
        if self.file.endswith("mea"):
            return True
        target = self.pipeline.outfile
        if os.path.exists(target):
            return True
        return os.path.exists(target + ".mat")

    @on_trait_change("pipeline.saved")
    def pipeline_saved(self):
        """A dumb hack"""
        # Toggling outfile fires a change so the cached ``mea_saved``
        # property is recomputed after the pipeline saves.
        current = self.outfile
        self.outfile = "asdf"
        self.outfile = current

    pipeline_view = MEAPView(
        Group(Item("pipeline", style="custom"), show_labels=False))
Esempio n. 6
0
 def default_traits_view(self):
     """Build the default view: one plot per channel and the window controls.

     Channels whose name starts with ``resp_corr`` are not plotted. An
     "ECG to use" selector is added when an ``ecg2`` channel is present.
     """
     contents = self.physiodata.contents
     plots = [Item(channel + "_ts", style="custom", show_label=False)
              for channel in sorted(contents)
              if not channel.startswith("resp_corr")]

     controls = [Item("start_time"), Item("window_size")]
     if "ecg2" in contents:
         controls.insert(0, Item("qrs_source_signal", label="ECG to use"))
     widgets = VGroup(*controls)

     layout = VGroup(VGroup(*plots),
                     Group(widgets, orientation="horizontal"))
     return MEAPView(layout,
                     width=1000,
                     height=600,
                     resizable=True,
                     win_title="Aqcknowledge Data",
                     buttons=[OKButton, CancelButton],
                     handler=DataPlotHandler())
Esempio n. 7
0
class RespirationProcessor(HasTraits):
    """Estimate respiration phase from a respiration belt and/or ICG z0.

    ``__init__`` chooses the input signals from the channels present in
    ``physiodata.contents`` and records the choice in ``state``:

    * ``"z0_resp"``  -- the belt is the respiration signal; z0 and dZ/dt are
      corrected by regressing respiration out of them
    * ``"resp"``     -- belt only; there is no z0/dZ/dt to correct
    * ``"z0"``       -- z0 doubles as the respiration signal
    * ``"unusable"`` -- neither channel is present; nothing can be processed

    ``process()`` detrends and low-pass filters the respiration signal, finds
    inhale/exhale onsets, corrects z0/dZ/dt when present, and fills
    ``respiration_cycle`` and ``respiration_amount``.
    """
    # Parameters
    resp_polort = DelegatesTo("physiodata")
    resp_high_freq_cutoff = DelegatesTo("physiodata")

    time = DelegatesTo("physiodata", "processed_respiration_time")

    # Data object and arrays to save
    physiodata = Instance(PhysioData)
    respiration_cycle = DelegatesTo("physiodata")
    respiration_amount = DelegatesTo("physiodata")

    # For visualization
    plots = Instance(VPlotContainer, transient=True)

    # Respiration plot
    resp_plot = Instance(Plot, transient=True)
    resp_plot_data = Instance(ArrayPlotData, transient=True)
    raw_resp_signal = Array()  # Could be filtered z0 or resp belt
    # Winsorized, z-scored version of raw_resp_signal
    resp_signal = Property(
        Array,
        depends_on="physiodata.resp_polort,physiodata.resp_high_freq_cutoff")
    polyfilt_resp = Array()
    lpfilt_resp = DelegatesTo("physiodata", "processed_respiration_data")
    resp_inhale_begin_times = DelegatesTo("physiodata")
    resp_inhale_begin_values = Array
    resp_exhale_begin_times = DelegatesTo("physiodata")
    resp_exhale_begin_values = Array

    z0_plot = Instance(Plot, transient=True)
    z0_plot_data = Instance(ArrayPlotData, transient=True)
    raw_z0_signal = Array()
    resp_corrected_z0 = DelegatesTo("physiodata")

    dzdt_plot = Instance(Plot, transient=True)
    dzdt_plot_data = Instance(ArrayPlotData, transient=True)
    raw_dzdt_signal = Array()
    resp_corrected_dzdt = DelegatesTo("physiodata")

    start_time = Range(low=0, high=10000.0, initial=0.0)
    window_size = Range(low=1.0, high=300.0, initial=30.0)

    state = Enum("unusable", "z0", "resp", "z0_resp")
    dirty = Bool(True)

    b_process = Button(label="Process Respiration")

    def __init__(self, **traits):
        super(RespirationProcessor, self).__init__(**traits)
        resp_inc = "respiration" in self.physiodata.contents
        z0_inc = "z0" in self.physiodata.contents
        if z0_inc and resp_inc:
            self.state = "z0_resp"
            messagebox("Using both respiration belt and ICG")
            z0_signal = self.physiodata.z0_data
            dzdt_signal = self.physiodata.dzdt_data
            resp_signal = self.physiodata.respiration_data
        elif resp_inc:
            self.state = "resp"
            messagebox("Only respiration belt data will be used.")
            # Was ``resp_data``; the belt channel is ``respiration_data``
            # (as read in the branch above).
            resp_signal = self.physiodata.respiration_data
            z0_signal = None
            dzdt_signal = None
        elif z0_inc:
            self.state = "z0"
            messagebox("Using only z0 channel to estimate respiration")
            resp_signal = self.physiodata.z0_data
            z0_signal = self.physiodata.z0_data
            dzdt_signal = self.physiodata.dzdt_data
            self.resp_polort = 1
        else:
            self.state = "unusable"
            messagebox("No respiration belt or z0 channels found")
            # No signals to set up; previously this fell through to a
            # NameError on the undefined signal variables.
            return

        # Truncate every signal to the shortest available length
        signals = [sig for sig in (resp_signal, z0_signal, dzdt_signal)
                   if sig is not None]
        minlen = min(len(sig) for sig in signals)

        # Establish a time array
        if resp_inc:
            self.time = TimeSeries(physiodata=self.physiodata,
                                   contains="respiration").time[:minlen]
        else:
            self.time = TimeSeries(physiodata=self.physiodata,
                                   contains="z0").time[:minlen]

        # Get out the final signals. The first 50 samples of the ICG
        # channels are replaced by the channel mean.
        if z0_inc:
            self.raw_resp_signal = z0_signal[:minlen].copy()
            self.raw_resp_signal[:50] = self.raw_resp_signal.mean()
            self.raw_z0_signal = z0_signal[:minlen].copy()
            self.raw_z0_signal[:50] = self.raw_z0_signal.mean()
            self.raw_dzdt_signal = dzdt_signal[:minlen].copy()
            self.raw_dzdt_signal[:50] = self.raw_dzdt_signal.mean()
        if resp_inc:
            # The belt, when present, takes precedence over z0
            self.raw_resp_signal = resp_signal[:minlen].copy()
        # if data already exists, it can't be dirty
        if self.physiodata.resp_exhale_begin_times.size > 0:
            self.dirty = False
        self.on_trait_change(self._parameter_changed,
                             "resp_polort,resp_high_freq_cutoff")

    @cached_property
    def _get_resp_signal(self):
        resp_signal = winsorize(self.raw_resp_signal)
        return (resp_signal - resp_signal.mean()) / resp_signal.std()

    def _parameter_changed(self):
        self.dirty = True

    def _b_process_fired(self):
        self.process()

    def process(self):
        """
        processes the respiration timeseries

        Steps: optional Legendre detrend, low-pass filter and z-score, find
        inhale (maxima) / exhale (minima) onsets, regress respiration out of
        z0 and dZ/dt (only when those channels exist), refresh any plots,
        and build the respiration cycle/amount arrays. Shows a message box
        and returns early if a filter changes the signal length.
        """
        sampling_rate = 1. / (self.time[1] - self.time[0])
        # Respiration belts can lose tension over time.
        # This removes linear trends
        if self.resp_polort > 0:
            pfit = legendre_detrend(self.resp_signal, self.resp_polort)
            if pfit.shape != self.time.shape:
                messagebox("Legendre detrend failed")
                return
            self.polyfilt_resp = pfit
        else:
            self.polyfilt_resp = self.resp_signal

        lpd = lowpass(self.polyfilt_resp, self.resp_high_freq_cutoff,
                      sampling_rate)
        if lpd.shape != self.time.shape:
            messagebox("lowpass filter failed")
            return

        self.lpfilt_resp = (lpd - lpd.mean()) / lpd.std()

        resp_inhale_begin_indices, resp_exhale_begin_indices = find_peaks(
            self.lpfilt_resp, maxima=True, minima=True)
        self.resp_inhale_begin_times = self.time[resp_inhale_begin_indices]
        self.resp_exhale_begin_times = self.time[resp_exhale_begin_indices]

        # z0/dZ/dt only exist (and are only loaded) when an ICG was recorded;
        # in "resp" mode their raw arrays are empty.
        if self.state in ("z0", "z0_resp"):
            self.resp_corrected_z0 = regress_out(self.raw_z0_signal,
                                                 self.lpfilt_resp)
            self.resp_corrected_dzdt = regress_out(self.raw_dzdt_signal,
                                                   self.lpfilt_resp)

        # update the scatterplot if we're interactive
        if self.resp_plot_data is not None:
            self.resp_plot_data.set_data("inhale_times",
                                         self.resp_inhale_begin_times)
            self.resp_plot_data.set_data(
                "inhale_values", self.lpfilt_resp[resp_inhale_begin_indices])
            self.resp_plot_data.set_data("exhale_times",
                                         self.resp_exhale_begin_times)
            self.resp_plot_data.set_data(
                "exhale_values", self.lpfilt_resp[resp_exhale_begin_indices])
            self.resp_plot_data.set_data("lpfilt_resp", self.lpfilt_resp)
            # The z0/dZ/dt plots are only built for the ICG states
            if self.dzdt_plot_data is not None:
                self.dzdt_plot_data.set_data("cleaned_data",
                                             self.resp_corrected_dzdt)
            if self.z0_plot_data is not None:
                self.z0_plot_data.set_data("cleaned_data",
                                           self.resp_corrected_z0)
        else:
            print("plot data is none")

        # Build a phase signal: 0 at exhale onsets (and both ends of the
        # recording), 0.5 at inhale onsets, linearly interpolated between.
        times = np.concatenate([
            np.zeros(1), self.resp_exhale_begin_times,
            self.resp_inhale_begin_times, self.time[-1, np.newaxis]
        ])
        vals = np.concatenate([
            np.zeros(1),
            np.zeros_like(resp_exhale_begin_indices),
            0.5 * np.ones_like(resp_inhale_begin_indices),
            np.zeros(1)
        ])
        srt = np.argsort(times)
        terp = interp1d(times[srt], vals[srt])
        self.respiration_cycle = terp(self.time)
        self.respiration_amount = np.abs(
            np.ediff1d(self.respiration_cycle, to_begin=0))

    def _resp_plot_default(self):
        # Create plotting components
        self.resp_plot_data = ArrayPlotData(
            time=self.time,
            inhale_times=self.resp_inhale_begin_times,
            inhale_values=self.resp_inhale_begin_values,
            exhale_times=self.resp_exhale_begin_times,
            exhale_values=self.resp_exhale_begin_values,
            filtered_resp=self.resp_signal,
            lpfilt_resp=self.lpfilt_resp)
        plot = Plot(self.resp_plot_data)
        plot.plot(("time", "filtered_resp"), color="blue", line_width=1)
        plot.plot(("time", "lpfilt_resp"), color="green", line_width=1)
        # Plot the inhalation peaks
        plot.plot(("inhale_times", "inhale_values"),
                  type="scatter",
                  marker="square")
        plot.plot(("exhale_times", "exhale_values"),
                  type="scatter",
                  marker="circle")
        plot.title = "Respiration"
        plot.title_position = "right"
        plot.title_angle = 270
        plot.padding = 20
        return plot

    def _z0_plot_default(self):
        self.z0_plot_data = ArrayPlotData(time=self.time,
                                          raw_data=self.raw_z0_signal,
                                          cleaned_data=self.resp_corrected_z0)
        plot = Plot(self.z0_plot_data)
        plot.plot(("time", "raw_data"), color="blue", line_width=1)
        plot.plot(("time", "cleaned_data"), color="green", line_width=1)
        plot.title = "z0"
        plot.title_position = "right"
        plot.title_angle = 270
        plot.padding = 20
        return plot

    def _dzdt_plot_default(self):
        """
        Creates a plot of the raw and respiration-corrected dZ/dt signals
        """
        # Create plotting components
        self.dzdt_plot_data = ArrayPlotData(
            time=self.time,
            raw_data=self.raw_dzdt_signal,
            cleaned_data=self.resp_corrected_dzdt)
        plot = Plot(self.dzdt_plot_data)
        plot.plot(("time", "raw_data"), color="blue", line_width=1)
        plot.plot(("time", "cleaned_data"), color="green", line_width=1)
        plot.title = "dZ/dt"
        plot.title_position = "right"
        plot.title_angle = 270
        plot.padding = 20
        return plot

    def _plots_default(self):
        plots_to_include = ()
        if self.state in ("z0_resp", "z0"):
            # All three plots share the respiration plot's time range
            self.index_range = self.resp_plot.index_range
            self.dzdt_plot.index_range = self.index_range
            self.z0_plot.index_range = self.index_range
            plots_to_include = [self.resp_plot, self.z0_plot, self.dzdt_plot]
        elif self.state == "resp":
            self.index_range = self.resp_plot.index_range
            plots_to_include = [self.resp_plot]
        return VPlotContainer(*plots_to_include)

    @on_trait_change("window_size,start_time")
    def update_plot_range(self):
        self.resp_plot.index_range.high = self.start_time + self.window_size
        self.resp_plot.index_range.low = self.start_time

    # ``enabled_when`` is evaluated as a Python expression over the object's
    # traits, so the state name must be a quoted string literal.
    proc_params_group = Group(Group(
        Item("resp_polort"),
        Item("resp_high_freq_cutoff"),
        Item("b_process",
             show_label=False,
             enabled_when="state != 'unusable' and dirty"),
        label="Resp processing options",
        show_border=True,
        orientation="vertical",
        springy=True),
                              orientation="horizontal")

    plot_group = Group(
        Group(Item("plots", editor=ComponentEditor(), width=800, height=700),
              show_labels=False), Item("start_time"), Item("window_size"))
    traits_view = MEAPView(
        VSplit(
            plot_group,
            proc_params_group,
        ),
        resizable=True,
        win_title="Process Respiration Data",
    )
Esempio n. 8
0
class GroupRegisterDZDT(HasTraits):
    physiodata = Instance(PhysioData)

    global_ensemble = Instance(GlobalEnsembleAveragedHeartBeat)
    srvf_lambda = DelegatesTo("physiodata")
    srvf_max_karcher_iterations = DelegatesTo("physiodata")
    srvf_update_min = DelegatesTo("physiodata")
    srvf_karcher_mean_subset_size = DelegatesTo("physiodata")
    dzdt_srvf_karcher_mean = DelegatesTo("physiodata")
    dzdt_karcher_mean = DelegatesTo("physiodata")
    dzdt_warping_functions = DelegatesTo("physiodata")
    srvf_use_moving_ensembled = DelegatesTo("physiodata")
    bspline_before_warping = DelegatesTo("physiodata")
    dzdt_functions_to_warp = DelegatesTo("physiodata")

    # Configures the slice of time to be registered
    srvf_t_min = DelegatesTo("physiodata")
    srvf_t_max = DelegatesTo("physiodata")
    dzdt_karcher_mean_time = DelegatesTo("physiodata")
    dzdt_mask = Array

    # Holds indices of beats used to calculate initial Karcher mean
    dzdt_karcher_mean_inputs = DelegatesTo("physiodata")
    dzdt_karcher_mean_over_iterations = DelegatesTo("physiodata")
    dzdt_num_inputs_to_group_warping = DelegatesTo("physiodata")
    srvf_iteration_distances = DelegatesTo("physiodata")
    srvf_iteration_energy = DelegatesTo("physiodata")

    # For calculating multiple modes
    all_beats_registered_to_initial = Bool(False)
    all_beats_registered_to_mode = Bool(False)
    n_modes = DelegatesTo("physiodata")
    max_kmeans_iterations = DelegatesTo("physiodata")
    mode_dzdt_karcher_means = DelegatesTo("physiodata")
    mode_cluster_assignment = DelegatesTo("physiodata")
    mode_dzdt_srvf_karcher_means = DelegatesTo("physiodata")

    # graphics items
    karcher_plot = Instance(Plot, transient=True)
    registration_plot = Instance(HPlotContainer, transient=True)
    karcher_plotdata = Instance(ArrayPlotData, transient=True)

    # Buttons
    b_calculate_karcher_mean = Button(label="Calculate Karcher Mean")
    b_align_all_beats = Button(label="Warp all")
    b_find_modes = Button(label="Discover Modes")
    b_edit_modes = Button(label="Score Modes")
    interactive = Bool(False)

    # Holds the karcher modes
    mode_beat_train = Instance(ModeKarcherBeatTrain)
    edit_listening = Bool(False,
                          desc="If true, update_plots is called"
                          " when a beat gets hand labeled")

    def __init__(self, **traits):
        """Prepare dZ/dt registration from the attached ``physiodata``.

        Detects a previously stored Karcher mean, slices the beats to
        register, restores stored warps when they match the beats, and draws
        an initial random subset for the Karcher mean.
        """
        super(GroupRegisterDZDT, self).__init__(**traits)
        logger.info("Initializing dZdt registration")

        # A stored Karcher mean counts only if both the SRVF and the
        # function-space versions are non-empty.
        has_srvf_mean = self.dzdt_srvf_karcher_mean.size > 0
        self.karcher_mean_calculated = has_srvf_mean and \
            self.dzdt_karcher_mean.size > 0

        # Slice (and optionally smooth) dZ/dt before any SRVF work
        self._update_original_functions()

        beats_shape = self.dzdt_functions_to_warp.shape
        self.n_functions = beats_shape[0]
        self.n_samples = beats_shape[1]
        self._init_warps()

        self.select_new_samples()

    def _mode_beat_train_default(self):
        """Default ``mode_beat_train``: a ModeKarcherBeatTrain on our physiodata."""
        logger.info("creating default mea_beat_train")
        assert self.physiodata is not None
        return ModeKarcherBeatTrain(physiodata=self.physiodata)

    def _init_warps(self):
        """ For loading warps from mea.mat """
        n_stored_warps = self.dzdt_warping_functions.shape[0]
        n_beats = self.dzdt_functions_to_warp.shape[0]
        # Stored warps are reused only if there is exactly one per beat
        if n_stored_warps != n_beats:
            return
        self.all_beats_registered_to_mode = True
        self._forward_warp_beats()

    def _global_ensemble_default(self):
        """Default ``global_ensemble``: the ensemble-averaged beat of physiodata."""
        ensemble = GlobalEnsembleAveragedHeartBeat(physiodata=self.physiodata)
        return ensemble

    def _karcher_plot_default(self):
        """
        Build the Karcher-mean plot lazily, only when a UI requests it.

        Also sets ``interactive``; other methods check that flag before
        pushing new results into the plot data.
        """
        self.interactive = True
        unreg_mean = self.global_ensemble.dzdt_signal
        # Until a Karcher mean is calculated, show the masked unregistered
        # ensemble average in its place.
        if self.karcher_mean_calculated:
            karcher_mean = self.dzdt_karcher_mean
        else:
            karcher_mean = unreg_mean[self.dzdt_mask]

        self.karcher_plotdata = ArrayPlotData(
            time=self.global_ensemble.dzdt_time,
            unregistered_mean=unreg_mean,
            karcher_mean=karcher_mean,
            karcher_time=self.dzdt_karcher_mean_time)
        karcher_plot = Plot(self.karcher_plotdata)
        karcher_plot.plot(("time", "unregistered_mean"),
                          line_width=1,
                          color="lightblue")
        karcher_plot.plot(("karcher_time", "karcher_mean"),
                          line_width=3,
                          color="blue")
        # Mark the Karcher time data source as sorted in ascending order
        karcher_plot.datasources['karcher_time'].sort_order = "ascending"

        return karcher_plot

    def _registration_plot_default(self):
        """
        Build the registration plots lazily, only when a UI requests them.

        Left: the Karcher mean with one registered function, which starts out
        as the same curve. Right: an image of all functions, registered when
        a full registration has been applied and unregistered otherwise.
        """
        unregistered = self.global_ensemble.dzdt_signal
        # Stand in the masked unregistered mean until a Karcher mean exists
        karcher_mean = (self.dzdt_karcher_mean if self.karcher_mean_calculated
                        else unregistered[self.dzdt_mask])

        self.single_registration_plotdata = ArrayPlotData(
            karcher_time=self.dzdt_karcher_mean_time,
            karcher_mean=karcher_mean,
            registered_func=karcher_mean)
        self.single_plot = Plot(self.single_registration_plotdata)
        for curve_name, curve_color in (("karcher_mean", "blue"),
                                        ("registered_func", "maroon")):
            self.single_plot.plot(("karcher_time", curve_name),
                                  line_width=2,
                                  color=curve_color)

        already_registered = (self.all_beats_registered_to_initial or
                              self.all_beats_registered_to_mode)
        functions = (self.registered_functions if already_registered
                     else self.dzdt_functions_to_warp)

        # Create a plot of all the functions registered or not
        self.all_registration_plotdata = ArrayPlotData(
            image_data=functions.copy())
        self.all_plot = Plot(self.all_registration_plotdata)
        self.all_plot.img_plot("image_data", colormap=jet)

        return HPlotContainer(self.single_plot, self.all_plot)

    def _forward_warp_beats(self):
        """ Create registered beats to plot, since they're not stored """
        logger.info("Applying warps to functions for plotting")
        # Re-create gam: shift/scale the stored warps by srvf_t_min and the
        # (srvf_t_max - srvf_t_min) span (presumably onto [0, 1] -- confirm)
        t_span = self.srvf_t_max - self.srvf_t_min
        gam = (self.dzdt_warping_functions - self.srvf_t_min) / t_span

        t = self.dzdt_karcher_mean_time
        t_start = t[0]
        t_length = t[-1] - t[0]
        warped = np.zeros_like(self.dzdt_functions_to_warp)
        for idx in range(self.n_functions):
            warped_times = t_length * gam[idx] + t_start
            warped[idx] = np.interp(warped_times, t,
                                    self.dzdt_functions_to_warp[idx])

        self.registered_functions = warped

    @on_trait_change("dzdt_num_inputs_to_group_warping")
    def select_new_samples(self):
        """Randomly choose, without replacement, the beats used for the Karcher mean.

        Up to ``dzdt_num_inputs_to_group_warping`` beats are chosen, capped
        at the number of beats available.
        """
        n_available = self.dzdt_functions_to_warp.shape[0]
        n_chosen = min(self.dzdt_num_inputs_to_group_warping, n_available)
        self.dzdt_karcher_mean_inputs = np.random.choice(
            n_available, size=n_chosen, replace=False)

    @on_trait_change(
        ("physiodata.srvf_lambda, "
         "physiodata.srvf_update_min, physiodata.srvf_use_moving_ensembled, "
         "physiodata.srvf_karcher_mean_subset_size, "
         "physiodata.bspline_before_warping"))
    def params_edited(self):
        """Mark all registration results stale and rebuild the input functions.

        NOTE(review): srvf_t_min/srvf_t_max also change the inputs but are
        not listed here -- confirm they are handled elsewhere.
        """
        self.dirty = True
        self.karcher_mean_calculated = False
        self.all_beats_registered_to_initial = \
            self.all_beats_registered_to_mode = False
        self._update_original_functions()

    def _update_original_functions(self):
        """
        Recompute the registration time slice and the functions to warp.

        Keeps the samples whose time relative to the R peak lies in
        ``[srvf_t_min, srvf_t_max]``. They come from
        ``physiodata.mea_dzdt_matrix`` when ``srvf_use_moving_ensembled`` is
        set and from ``physiodata.dzdt_matrix`` otherwise. Each function is
        optionally smoothed with a B-spline. If the registration image plot
        exists, it is refreshed.
        """
        logger.info("updating time slice and functions to register")

        # Sample times relative to R. Uses builtin ``float``: ``np.float``
        # was only an alias for it and is removed in newer NumPy.
        dzdt_time = (np.arange(self.physiodata.dzdt_matrix.shape[1],
                               dtype=float) - self.physiodata.dzdt_pre_peak)
        self.dzdt_mask = ((dzdt_time >= self.srvf_t_min) &
                          (dzdt_time <= self.srvf_t_max))
        self.dzdt_karcher_mean_time = dzdt_time[self.dzdt_mask]

        # Extract corresponding data
        if self.srvf_use_moving_ensembled:
            source = self.physiodata.mea_dzdt_matrix
        else:
            source = self.physiodata.dzdt_matrix
        self.dzdt_functions_to_warp = source[:, self.dzdt_mask]

        if self.bspline_before_warping:
            logger.info("Smoothing inputs with B-Splines")
            t = self.dzdt_karcher_mean_time
            self.dzdt_functions_to_warp = np.vstack(
                [UnivariateSpline(t, func, s=0.05)(t)
                 for func in self.dzdt_functions_to_warp])

        # ``interactive`` is set by the Karcher plot, but the image plot is
        # built separately and may not exist yet.
        if self.interactive and \
                getattr(self, "all_registration_plotdata", None) is not None:
            self.all_registration_plotdata.set_data(
                "image_data", self.dzdt_functions_to_warp.copy())
            self.all_plot.request_redraw()

    def _b_calculate_karcher_mean_fired(self):
        """Button handler: compute the initial Karcher mean."""
        return self.calculate_karcher_mean()

    def calculate_karcher_mean(self):
        """
        Calculates an initial Karcher Mean.

        Registers the beats indexed by ``dzdt_karcher_mean_inputs``. The
        function-space and SRVF Karcher means are stored in
        ``dzdt_karcher_mean`` and ``dzdt_srvf_karcher_mean``. The
        ``RegistrationProblem`` is kept as ``self.rp``.
        """
        reg_problem = RegistrationProblem(
            self.dzdt_functions_to_warp[self.dzdt_karcher_mean_inputs].T,
            sample_times=self.dzdt_karcher_mean_time,
            max_karcher_iterations=self.srvf_max_karcher_iterations,
            lambda_value=self.srvf_lambda,
            update_min=self.srvf_update_min)
        reg_problem.run_registration_parallel()
        self.dzdt_karcher_mean = reg_problem.function_karcher_mean
        self.dzdt_srvf_karcher_mean = reg_problem.srvf_karcher_mean
        self.karcher_mean_calculated = True

        # Update plots if this is interactive
        if self.interactive:
            self.karcher_plotdata.set_data("karcher_mean",
                                           self.dzdt_karcher_mean)
            self.karcher_plot.request_redraw()
            # The registration plot is built separately and may not exist
            if getattr(self, "single_registration_plotdata", None) is not None:
                self.single_registration_plotdata.set_data(
                    "karcher_mean", self.dzdt_karcher_mean)
                self.single_plot.request_redraw()
        self.rp = reg_problem

    def _b_align_all_beats_fired(self):
        """Button handler: warp every beat to the initial Karcher mean."""
        return self.align_all_beats_to_initial()

    def _b_find_modes_fired(self):
        """Button handler: run mode detection."""
        return self.detect_modes()

    def detect_modes(self):
        """
        Cluster the beats into modes and register each beat to its mode.

        Uses the SRD-based clustering method described in Kurtek 2017:

        1. The warping functions from the initial registration are min-max
           rescaled and converted to square-root densities (SRDs), which are
           compared pairwise with ``fisher_rao_dist``.
        2. ``complete`` linkage on those distances, cut into ``self.n_modes``
           clusters, gives initial assignments that are refined by k-means
           style iterations (at most ``self.max_kmeans_iterations``) until
           they stop changing.
        3. The dZ/dt functions in each cluster are group-registered to that
           cluster's own Karcher mean.

        Requires ``karcher_mean_calculated`` and
        ``all_beats_registered_to_initial``; otherwise ``fail`` is called and
        nothing is modified.

        Sets ``dzdt_warping_functions``, ``registered_functions``,
        ``mode_dzdt_karcher_means``, ``mode_cluster_assignment`` and
        ``all_beats_registered_to_mode``.
        """
        if not self.karcher_mean_calculated:
            fail("Must calculate an initial Karcher mean first")
            return
        if not self.all_beats_registered_to_initial:
            fail("All beats must be registered to the initial Karcher mean")
            return

        # Calculate the SRDs: square root of the first differences of each
        # (globally min-max rescaled) warping function
        dzdt_functions_to_warp = self.dzdt_functions_to_warp.T
        warps = self.dzdt_warping_functions.T
        wmax = warps.max()
        wmin = warps.min()
        warps = (warps - wmin) / (wmax - wmin)
        densities = np.diff(warps, axis=0)
        SRDs = np.sqrt(densities)
        # Pairwise distances, condensed to the upper triangle for the linkage
        srd_pairwise = pairwise_distances(SRDs.T, metric=fisher_rao_dist)
        tri = srd_pairwise[np.triu_indices_from(srd_pairwise, 1)]
        linkage = complete(tri)

        def cluster_karcher_means(initial_assignments):
            """
            Perform one k-means style iteration on the SRDs.

            Computes a Karcher mean for each cluster in
            ``initial_assignments``, rescales the means, and reassigns every
            SRD to its closest mean. Returns (assignments, distances, scaled
            cluster means, warping functions).
            """
            cluster_means = {}
            cluster_ids = np.unique(initial_assignments).tolist()
            warping_functions = np.zeros_like(SRDs)

            # Calculate a Karcher mean for each cluster
            for cluster_id in cluster_ids:
                logger.info("Cluster ID: %s", cluster_id)
                cluster_id_mask = initial_assignments == cluster_id
                cluster_srds = SRDs[:, cluster_id_mask]

                # If there is only a single SRD in this cluster, it is the mean
                if cluster_id_mask.sum() == 1:
                    cluster_means[cluster_id] = cluster_srds
                    continue

                # Run group registration to get Karcher mean
                cluster_reg = RegistrationProblem(
                    cluster_srds,
                    sample_times=np.arange(SRDs.shape[0], dtype=float),
                    max_karcher_iterations=self.srvf_max_karcher_iterations,
                    lambda_value=self.srvf_lambda,
                    update_min=self.srvf_update_min)
                cluster_reg.run_registration_parallel()
                cluster_means[cluster_id] = cluster_reg.function_karcher_mean
                warping_functions[:, cluster_id_mask] = \
                    cluster_reg.mean_to_orig_warps

            # Scale the cluster Karcher means so the FR distance works
            scaled_cluster_means = {}
            for k, v in cluster_means.items():
                scaled_cluster_means[k] = rescale_cluster_mean(v)

            # There are now k cluster means, which is closest for each SRD?
            # Also save its distance to its cluster's Karcher mean
            srd_cluster_assignments, srd_cluster_distances = get_closest_mean(
                SRDs, scaled_cluster_means)
            return (srd_cluster_assignments, srd_cluster_distances,
                    scaled_cluster_means, warping_functions)

        # Iterate until assignments stabilize
        last_assignments = fcluster(linkage,
                                    self.n_modes,
                                    criterion="maxclust")
        # Keep the hierarchical assignments if no k-means iteration runs
        # (max_kmeans_iterations < 1); otherwise ``assignments`` is undefined
        assignments = last_assignments
        stabilized = False
        n_iters = 0
        while not stabilized and n_iters < self.max_kmeans_iterations:
            logger.info("Iteration %d", n_iters)
            assignments, _, _, _ = cluster_karcher_means(last_assignments)
            stabilized = np.all(last_assignments == assignments)
            last_assignments = assignments.copy()
            n_iters += 1

        # Finalize the clusters by aligning all the functions to the cluster
        # mean
        cluster_means = {}
        cluster_ids = np.unique(assignments)
        warping_functions = np.zeros_like(dzdt_functions_to_warp)
        self.registered_functions = np.zeros_like(self.dzdt_functions_to_warp)
        # Identity warp used for beats that are alone in their cluster.
        # NOTE(review): assumes RegistrationProblem.mean_to_orig_warps is on
        # [0, 1], since both are mapped to [srvf_t_min, srvf_t_max] below.
        identity_warp = np.linspace(0, 1, dzdt_functions_to_warp.shape[0])
        # Calculate a Karcher mean for each cluster
        for cluster_id in cluster_ids:
            cluster_id_mask = assignments == cluster_id
            cluster_funcs = dzdt_functions_to_warp[:, cluster_id_mask]

            # A lone function is its own mean and needs no warping; without
            # this its warp and registered function would stay all zeros
            if cluster_id_mask.sum() == 1:
                cluster_means[cluster_id] = cluster_funcs
                warping_functions[:, cluster_id_mask] = \
                    identity_warp[:, np.newaxis]
                self.registered_functions[cluster_id_mask] = \
                    self.dzdt_functions_to_warp[cluster_id_mask]
                continue

            # Run group registration to get Karcher mean
            cluster_reg = RegistrationProblem(
                cluster_funcs,
                sample_times=self.dzdt_karcher_mean_time,
                max_karcher_iterations=self.srvf_max_karcher_iterations,
                lambda_value=self.srvf_lambda,
                update_min=self.srvf_update_min)
            cluster_reg.run_registration_parallel()
            cluster_means[cluster_id] = cluster_reg.function_karcher_mean
            warping_functions[:, cluster_id_mask] = \
                cluster_reg.mean_to_orig_warps
            self.registered_functions[
                cluster_id_mask] = cluster_reg.registered_functions.T

        # Save the warps to the modes as the final warping functions
        self.dzdt_warping_functions = warping_functions.T \
                                      * (self.srvf_t_max - self.srvf_t_min) \
                                      + self.srvf_t_min

        # Re-number the clusters 0..k-1 in sorted order of the original ids
        cluster_ids = sorted(cluster_means.keys())
        final_assignments = np.zeros_like(assignments)
        final_modes = []
        for final_id, orig_id in enumerate(cluster_ids):
            final_assignments[assignments == orig_id] = final_id
            final_modes.append(cluster_means[orig_id].squeeze())

        self.mode_dzdt_karcher_means = np.row_stack(final_modes)
        self.mode_cluster_assignment = final_assignments
        self.all_beats_registered_to_mode = True

    def align_all_beats_to_initial(self):
        """
        Register every dZ/dt function to the initial Karcher mean.

        Each row of ``self.dzdt_functions_to_warp`` is converted to an SRVF
        (its derivative divided by the square root of the derivative's
        absolute value). Both it and ``self.dzdt_srvf_karcher_mean`` are
        normalized to unit norm, and ``dp`` (with ``self.srvf_lambda``) is
        used to compute a warp between them. Each warp is rescaled to [0, 1]
        and used to resample the original function.

        Sets ``registered_functions``, ``dzdt_warping_functions`` (the warps
        mapped onto [srvf_t_min, srvf_t_max]) and
        ``all_beats_registered_to_initial``. When ``self.interactive`` is
        set, a progress dialog and the registration plots are updated while
        it runs. Logs a warning and returns early if no Karcher mean has
        been calculated yet.
        """
        if not self.karcher_mean_calculated:
            logger.warn("Calculate Karcher mean first")
            return
        logger.info("Aligning all beats to the Karcher Mean")
        if self.interactive:
            progress = ProgressDialog(
                title="ICG Warping",
                min=0,
                max=len(self.physiodata.peak_times),
                show_time=True,
                message="Warping dZ/dt to Karcher Mean...")
            progress.open()

        template_func = self.dzdt_srvf_karcher_mean
        normed_template_func = template_func / np.linalg.norm(template_func)
        # Derivative along axis 0 of the transposed array, i.e. along time
        fy, _ = np.gradient(self.dzdt_functions_to_warp.T, 1, 1)
        # SRVF transform; eps keeps flat segments from dividing by zero
        srvf_functions = (fy / np.sqrt(np.abs(fy) + eps)).T

        gam = np.zeros(self.dzdt_functions_to_warp.shape, dtype=float)
        aligned_functions = self.dzdt_functions_to_warp.copy()

        t = self.dzdt_karcher_mean_time
        t_span = t[-1] - t[0]
        for k in range(self.n_functions):
            q_c = srvf_functions[k] / np.linalg.norm(srvf_functions[k])
            G, T = dp(normed_template_func, t, q_c, t, t, t, self.srvf_lambda)
            gam0 = np.interp(t, T, G)
            gam[k] = (gam0 - gam0[0]) / (gam0[-1] - gam0[0])  # rescale to [0, 1]
            aligned_functions[k] = np.interp(t_span * gam[k] + t[0], t,
                                             self.dzdt_functions_to_warp[k])

            if self.interactive:
                # Update the image plot
                self.all_registration_plotdata.set_data(
                    "image_data", aligned_functions)

                # Update the registration plot
                self.single_registration_plotdata.set_data(
                    "registered_func", aligned_functions[k])
                self.single_plot.request_redraw()

                progress.update(k)

        self.registered_functions = aligned_functions.copy()

        self.dzdt_warping_functions = gam * (
            self.srvf_t_max - self.srvf_t_min) + self.srvf_t_min

        if self.interactive:
            # Use n_functions, not the loop variable: the loop variable is
            # undefined when there are no functions to align
            progress.update(self.n_functions)

        self.all_beats_registered_to_initial = True

    def _b_edit_modes_fired(self):
        self.mode_beat_train.edit_traits()

    def _point_plots_default(self):
        """
        Lazily build the default value of ``point_plots``.

        The plots are created here rather than in ``__init__`` so they are
        only constructed once a UI asks for them. The result is a vertical
        stack with one line plot per signal ("pep", "x_times", "lvet"), each
        drawn against ``peak_times``. As a side effect,
        ``self.point_plotdata`` is (re)assigned.
        """
        physio = self.physiodata
        self.point_plotdata = ArrayPlotData(
            peak_times=physio.peak_times.flatten(),
            x_times=physio.x_indices - physio.dzdt_pre_peak,
            lvet=physio.lvet,
            pep=physio.pep)

        stack = VPlotContainer(resizable="hv",
                               bgcolor="lightgray",
                               fill_padding=True,
                               padding=10)

        signal_names = ["pep", "x_times", "lvet"]
        for name in signal_names:
            sub_plot = Plot(self.point_plotdata)
            sub_plot.plot(("peak_times", name), line_width=2)
            sub_plot.title = name
            stack.add(sub_plot)

        stack.padding_top = 10
        return stack

    # Left pane of the window: the Karcher-mean and registration plots, the
    # SRVF parameters, and the two workflow buttons. "Step 1" is enabled
    # while ``dirty`` is set; "Step 2" (b_align_all_beats, which runs
    # align_all_beats_to_initial) is enabled once a Karcher mean exists.
    mean_widgets = VGroup(VGroup(Item("karcher_plot",
                                      editor=ComponentEditor(),
                                      show_label=False),
                                 Item("registration_plot",
                                      editor=ComponentEditor(),
                                      show_label=False),
                                 show_labels=False),
                          Item("srvf_use_moving_ensembled",
                               label="Use Moving Ensembled dZ/dt"),
                          Item("bspline_before_warping",
                               label="B Spline smoothing"),
                          Item("srvf_t_min", label="Epoch Start Time"),
                          Item("srvf_t_max", label="Epoch End Time"),
                          Item("srvf_lambda", label="Lambda Value"),
                          Item("dzdt_num_inputs_to_group_warping",
                               label="Template uses N beats"),
                          Item("srvf_max_karcher_iterations",
                               label="Max Karcher Iterations"),
                          HGroup(
                              Item("b_calculate_karcher_mean",
                                   label="Step 1:",
                                   enabled_when="dirty"),
                              Item("b_align_all_beats",
                                   label="Step 2:",
                                   enabled_when="karcher_mean_calculated")),
                          label="Initial Karcher Mean")

    # Right pane of the window: mode-detection settings. b_find_modes runs
    # detect_modes once all beats are registered to the initial mean;
    # b_edit_modes opens the mode beat train editor once beats are
    # registered to their modes.
    mode_widgets = VGroup(
        Item("n_modes", label="Number of Modes/Clusters"),
        Item("max_kmeans_iterations"),
        Item("b_find_modes",
             show_label=False,
             enabled_when="all_beats_registered_to_initial"),
        Item("b_edit_modes",
             show_label=False,
             enabled_when="all_beats_registered_to_mode"))

    # Main window: initial-mean controls and mode controls side by side.
    traits_view = MEAPView(HSplit(mean_widgets, mode_widgets),
                           resizable=True,
                           win_title="ICG Warping Tools",
                           width=800,
                           height=700,
                           buttons=[OKButton, CancelButton])
Esempio n. 9
0
class MEAPConfig(HasTraits):
    """
    Collection of MEAP processing parameters with their default values.

    Parameters are grouped by the section comments below (smoothing,
    waveform extraction, respiration, R-peak detection, secondary heartbeat
    detection, ECG2, moving ensembling, B-point classifier, Doppler points,
    SRVF warping). ``extraction_group`` and ``pt_params_group`` are TraitsUI
    groups that lay out some of these parameters.
    """

    # Parameters for point marking
    apply_ecg_smoothing = CBool(True)
    ecg_smoothing_window_len = Int(5) # NOTE: Changed in 1.1 from 20
    apply_imp_smoothing = CBool(True)
    imp_smoothing_window_len = Int(40)
    apply_bp_smoothing = CBool(True)
    bp_smoothing_window_len = Int(80)

    # Parameters for waveform extraction
    # (each *_pre_peak / *_post_peak pair sets the window extracted around a
    # peak; the trailing Range(...) comments are older, disabled definitions)
    peak_window = CInt(80) #Range(low=5,high=400, value= 200)
    ecg_pre_peak = CInt(300) #Range(low=50, high=500, value=300)
    ecg_post_peak = CInt(400) #Range(low=100, high=700, value=400)
    dzdt_pre_peak = CInt(300) #Range(50, 500, value=300)
    dzdt_post_peak = CInt(700) #Range(100, 1000, value=700)
    doppler_pre_peak = CInt(300) #Range(50, 500, value=300)
    doppler_post_peak = CInt(700) #Range(100, 1000, value=700)
    bp_pre_peak = CInt(300) #Range(50, 500, value=300)
    bp_post_peak = CInt(1000) #Range(100, 2500, value=1200)
    systolic_pre_peak = CInt(300) #Range(50, 500, value=300)
    systolic_post_peak = CInt(1000) #Range(100, 2500, value=1200)
    diastolic_pre_peak = CInt(300) #Range(50, 500, value=300)
    diastolic_post_peak = CInt(1000) #Range(100, 2500, value=1200)
    stroke_volume_equation = Enum("Kubicek","Sramek-Bernstein")
    # UI layout for the extraction parameters.
    # NOTE(review): the second group references mhd_bandpass_min,
    # mhd_bandpass_max, mhd_smoothing_window_len, mhd_smoothing_window,
    # qrs_to_mhd_ratio, combined_smoothing_window_len and
    # combined_smoothing_window, none of which are declared in this class.
    # Presumably they are declared wherever this group is displayed -- confirm,
    # since a view built from this group on MEAPConfig alone would lack them.
    extraction_group = Group(
        Group(
            Item("peak_window"),
            Item("ecg_pre_peak"),
            Item("ecg_post_peak"),
            Item("dzdt_pre_peak"),
            Item("dzdt_post_peak"),
            Item("bp_pre_peak"),
            Item("bp_post_peak"),
            Item("stroke_volume_equation"),
            orientation="vertical",
            show_border=True,
            label = "Waveform Extraction Parameters",
            springy=True
            ),
        Group(
            Item("mhd_bandpass_min"),
            Item("mhd_bandpass_max"),
            Item("mhd_smoothing_window_len"),
            Item("mhd_smoothing_window"),
            Item("qrs_to_mhd_ratio"),
            Item("combined_smoothing_window_len"),
            Item("combined_smoothing_window"),
            enabled_when="subject_in_mri"
        )
    )

    # Respiration analysis parameters
    process_respiration = CBool(True)
    resp_polort = CInt(7)
    resp_high_freq_cutoff = CFloat(0.35)
    regress_out_resp = Bool(False)

    # parameters for processing the raw data before PT detecting
    # MRI-specific
    subject_in_mri = CBool(False)
    peak_detection_algorithm = Enum("Pan Tomkins 83", "Multisignal", "ECG2")

    # PanTomkins algorithm
    bandpass_min = Range(low=1,high=200,initial=5,value=5)
    bandpass_max = Range(low=1,high=200,initial=15,value=15)
    smoothing_window_len = Range(low=10,high=1000,initial=100,value=100)
    smoothing_window = Enum(SMOOTHING_WINDOWS)
    pt_adjust = Range(low=-2.,high=2.,value=0.00)
    peak_threshold=CFloat
    apply_filter = CBool(True)
    apply_diff_sq = CBool(True)
    apply_smooth_ma = CBool(True)
    # UI layout for the R-peak detection options (apply_smooth_ma is declared
    # above but not shown in this group).
    pt_params_group = Group(
        Item("apply_filter"),
        Item("bandpass_min"),#editor=RangeEditor(enter_set=True)),
        Item("bandpass_max"),#editor=RangeEditor(enter_set=True)),
        Item("smoothing_window_len"),#editor=RangeEditor(enter_set=True)),
        Item("smoothing_window"),
        Item("pt_adjust"),#editor=RangeEditor(enter_set=True)),
        Item("apply_diff_sq"),
        Item("subject_in_mri"),
        label="R peak detection options",
        show_border=True,
        orientation="vertical",
        springy=True
    )

    # Second signal heartbeat detection
    use_secondary_heartbeat = CBool(False)
    secondary_heartbeat = Enum("dzdt", "pulse_ox", "bp")
    secondary_heartbeat_pre_msec = CInt(400)
    secondary_heartbeat_abs = CBool(True)
    secondary_heartbeat_window = Enum(SMOOTHING_WINDOWS)
    secondary_heartbeat_window_len = CInt(801)
    secondary_heartbeat_n_likelihood_bins = CInt(15)

    # ECG2 parameters
    use_ECG2 = CBool(False)
    qrs_signal_source = Enum("ecg", "ecg2")
    ecg2_weight = Range(low=0., high=1.,value=0.5)

    # Moving Ensembling Parameters
    mea_window_type = Enum("Seconds","Beats")
    mea_n_neighbors = Range(low=0, high=60, value=8)
    mea_window_secs = Range(low=1., high=60, value=15.)
    mea_exp_power = Enum(2,3,4,5,6)
    mea_func_name = Enum("linear","exponential","flat")
    mea_weight_direction = Enum("symmetric", "before", "after")
    mea_smooth_hr = CBool(True)
    use_trimmed_co = CBool(True)

    #bpoint classifier parameters
    bpoint_classifier_pre_point_msec = CInt(20)
    bpoint_classifier_post_point_msec = CInt(20)
    bpoint_classifier_sample_every_n_msec = CInt(1)
    bpoint_classifier_false_distance_min = CInt(5)
    bpoint_classifier_use_bpoint_prior = CBool(True)
    bpoint_classifier_include_derivative = CBool(True)

    # Doppler D point config
    db_point_type = Enum("min", "max")
    db_point_window_len = CInt(20)
    dx_point_type = Enum("min", "max")
    dx_point_window_len = CInt(20)

    # SRVF-warping parameters
    srvf_lambda = CFloat(0.0)
    srvf_max_karcher_iterations = CInt(15)
    srvf_update_min = CFloat(0.01)
    srvf_karcher_mean_subset_size = CInt(30)
    srvf_multi_mode_variance_cutoff = CFloat(0.3)
    srvf_use_moving_ensembled = CBool(False)
    dzdt_num_inputs_to_group_warping = CInt(25)
    srvf_t_min = CInt(0) # Time in msec relative to R
    srvf_t_max = CInt(150) # Also time in msec relative to R
    bspline_before_warping = CBool(True)
    n_modes = CInt(5)
    max_kmeans_iterations = CInt(5)
Esempio n. 10
0
class Importer(HasTraits):
    """
    Base class for data importers.

    Holds the channels found in ``path``, an optional channel-name mapping
    read from ``mapping_txt``, and the participant measurements entered in
    the import dialog. Subclasses must implement ``get_physiodata``.
    """
    path = File
    channels = List(Instance(Channel))
    # Text file with one "<channel name> -> <signal type>" entry per line
    mapping_txt = File()
    # Channel name -> signal type, parsed from ``mapping_txt``
    mapper = Dict()
    
    # --- Subject information
    subject_age = CFloat(20.)
    subject_gender = Enum("M","F")
    subject_weight = CFloat(135.,label="Weight (lbs)")
    subject_height_ft = CInt(5,label="Height (ft)",
                        desc="Subject's height in feet")
    subject_height_in = CInt(10,label = "Height (in)",
                        desc="Subject's height in inches")
    subject_electrode_distance_front = CFloat(32,
                        label="Impedance electrode distance (front)")
    subject_electrode_distance_back = CFloat(32,
                        label="Impedance electrode distance (back)")
    subject_electrode_distance_right = CFloat(32,
                        label="Impedance electrode distance (right)")
    subject_electrode_distance_left = CFloat(32,
                        label="Impedance electrode distance (left)")
    subject_in_mri = CBool(False)
    subject_control_base_impedance = CFloat(0.,label="Control Impedance",
                        desc="If in MRI, store the z0 value from outside the MRI")
    subject_resp_max = CFloat(100.,label="Respiration circumference max (cm)")
    subject_resp_min = CFloat(0.,label="Respiration circumference min (cm)")
    
    def _channel_map_default(self):
        # NOTE(review): no ``channel_map`` trait is declared in this class, so
        # this default only takes effect if a subclass declares one.
        return ChannelMapper()
    
    def _mapping_txt_changed(self):
        """
        Parse ``mapping_txt`` into ``self.mapper``.

        Each line must look like ``channel name -> signal type``. Blank lines
        are skipped silently; lines without exactly one ``->`` and targets
        not in ``SUPPORTED_SIGNALS`` are logged and skipped. Does nothing if
        the file does not exist.
        """
        if not os.path.exists(self.mapping_txt):
            return
        with open(self.mapping_txt) as f:
            lines = [l.strip() for l in f]
        self.mapper = {}
        for line in lines:
            if not line:
                continue
            strip = line.split("->")
            if len(strip) != 2:
                logger.info("Unable to parse line: %s", line)
                continue
            map_from, map_to = [l.strip() for l in strip]
            if map_to in SUPPORTED_SIGNALS:
                self.mapper[map_from] = map_to
                logger.info("mapping %s to %s",map_from,map_to)
            else:
                logger.info("Unrecognized channel: %s", map_to)
        
    traits_view = MEAPView(
        Group(
            Item("channels", editor=channel_table),
            label="Specify channel contents",
            show_border=True,
            show_labels=False
            ),
         Group(
             Item("subject_age"), Item("subject_gender"),
             Item("subject_height_ft"), Item("subject_height_in"),
             Item("subject_weight"),
             Item("subject_electrode_distance_front"),
             Item("subject_electrode_distance_back"),
             Item("subject_electrode_distance_left"),
             Item("subject_electrode_distance_right"),
             Item("subject_resp_max"),Item("subject_resp_min"),
             Item("subject_in_mri"),Item('subject_control_base_impedance'),
             label="Participant Measurements"
             ),
        buttons=[OKButton, CancelButton],
        resizable=True,
        win_title="Import Data",
        )
    
    
    def guess_channel_contents(self):
        """
        Set ``contains`` on each channel.

        Channels named in ``self.mapper`` get their mapped signal type.
        Otherwise the lower-cased channel name is matched against keywords,
        and a signal type is only guessed when no channel was explicitly
        mapped to that type.
        """
        mapped = set(self.mapper.values())
        for chan in self.channels:
            if chan.name in self.mapper:
                chan.contains = self.mapper[chan.name]
                continue
            cname = chan.name.lower()
            if "magnitude" in cname and "z0" not in mapped:
                chan.contains = "z0"
            elif "ecg" in cname and "ecg" not in mapped:
                chan.contains = "ecg"
            # The keyword alternatives are parenthesized so the "not already
            # mapped" check applies to both (``and`` binds tighter than ``or``)
            elif ("dz/dt" in cname or "derivative" in cname) and \
                    "dzdt" not in mapped:
                chan.contains = "dzdt"
            elif "diastolic" in cname and "diastolic" not in mapped:
                chan.contains = "diastolic"
            elif "systolic" in cname and "systolic" not in mapped:
                chan.contains = "systolic"
            elif ("blood pressure" in cname or "bp" in cname) and \
                    "bp" not in mapped:
                chan.contains = "bp"
            elif "resp" in cname and "respiration" not in mapped:
                chan.contains = "respiration"
            elif "trigger" in cname and "mri_trigger" not in mapped:
                chan.contains = "mri_trigger"
            elif "stimulus" in cname and "event" not in mapped:
                chan.contains = "event"
            
                
    def subject_data(self):
        """Return the participant measurements as a dict keyed by trait name."""
        # NOTE(review): subject_control_base_impedance is shown in the view
        # but not included here; confirm whether callers add it separately.
        return dict(subject_age = self.subject_age,
            subject_gender = self.subject_gender,
            subject_weight = self.subject_weight,
            subject_height_ft = self.subject_height_ft, 
            subject_height_in = self.subject_height_in,
            subject_electrode_distance_front = self.subject_electrode_distance_front,  
            subject_electrode_distance_back = self.subject_electrode_distance_back ,
            subject_electrode_distance_left = self.subject_electrode_distance_left ,
            subject_electrode_distance_right = self.subject_electrode_distance_right ,
            subject_resp_max = self.subject_resp_max,
            subject_resp_min = self.subject_resp_min,
            subject_in_mri = self.subject_in_mri
            )
    
    def get_physiodata(self):
        """Build and return the imported data; must be implemented by subclasses."""
        raise NotImplementedError("Must be overwritten by subclass")
Esempio n. 11
0
class BatchFileTool(HasTraits):
    """
    Create a spreadsheet listing input files for batch processing.

    Files in ``input_directory`` ending in ``input_file_extension`` (but not
    in ``file_suffix``) are listed together with an output path formed by
    replacing the extension with ``file_suffix``. "Save Spreadsheet" writes
    one row per file to ``spreadsheet_file``; the participant-measurement
    columns are left empty for the user to fill in.
    """

    # For saving outputs
    file_suffix = Str("_finished.mea.mat")
    input_file_extension = Enum(".mea.mat", ".acq", ".mat")
    overwrite = Bool(False)
    input_directory = Directory()
    output_directory = Directory()
    files = List(Instance(FileToProcess))
    spreadsheet_file = File(exists=False)
    b_save = Button("Save Spreadsheet")

    def _input_file_extension_changed(self):
        self._input_directory_changed()

    def _file_suffix_changed(self):
        # Output paths and the excluded files both depend on the suffix
        self._input_directory_changed()

    def _input_directory_changed(self):
        """Rebuild ``files`` from the current directory, extension and suffix."""
        # Without an input directory the glob below would search "/"
        if not self.input_directory:
            self.files = []
            return

        potential_files = glob(self.input_directory + "/*" +
                               self.input_file_extension)
        # Skip files that are already outputs; an empty suffix would match
        # (and therefore exclude) every file
        if self.file_suffix:
            potential_files = [f for f in potential_files
                               if not f.endswith(self.file_suffix)]

        def make_output_file(input_file):
            """Replace the input extension with ``file_suffix``."""
            return input_file[:-len(self.input_file_extension)] + \
                self.file_suffix

        self.files = [
            FileToProcess(input_file=f, output_file=make_output_file(f))
            for f in potential_files
        ]

        # If no output directory is set, use the input directory.
        # NOTE(review): output paths above are always placed next to the
        # input file; ``output_directory`` is not used to build them.
        if self.output_directory == '':
            self.output_directory = self.input_directory

    def _b_save_fired(self):
        """Write one row per file (with empty measurement columns) to Excel."""
        def get_row():
            return {
                "file": "",
                "outfile": "",
                "weight": "",
                "height_ft": "",
                "height_in": "",
                "electrode_distance_front": "",
                "electrode_distance_back": "",
                "electrode_distance_left": "",
                "electrode_distance_right": "",
                "resp_max": "",
                "resp_min": "",
                "in_mri": "",
                "control_base_impedance": ""
            }

        rows = []
        for f in self.files:
            row = get_row()
            row['file'] = f.input_file
            row['outfile'] = f.output_file
            rows.append(row)
        df = pd.DataFrame(rows)
        logger.info("Writing spreadsheet to %s", self.spreadsheet_file)
        df.to_excel(self.spreadsheet_file, index=False)

    mean_widgets = VGroup(Item("input_file_extension"),
                          Item("input_directory"), Item("output_directory"),
                          Item("file_suffix"), Item("spreadsheet_file"),
                          Item("b_save", show_label=False))

    traits_view = MEAPView(HSplit(
        Item("files", editor=files_table, show_label=False), mean_widgets),
                           resizable=True,
                           win_title="Create Batch Spreadsheet",
                           width=800,
                           height=700,
                           buttons=[OKButton, CancelButton])
Esempio n. 12
0
class BatchGroupRegisterDZDT(HasTraits):
    """
    Batch tool that runs dZ/dt group registration on many ``.mea.mat`` files.

    The SRVF parameters are delegated to a default ``PhysioData`` instance,
    so the dialog starts from its defaults. Every ``*mea.mat`` file in
    ``input_directory`` (except ``*_aligned.mea.mat`` files) is listed with
    an output path in ``output_directory``. "Run" processes the unfinished
    files with ``process_physio_file`` in a multiprocessing pool.
    """
    # Dummy physiodata to get defaults
    physiodata = Instance(PhysioData)

    # Parameters for SRVF registration
    srvf_lambda = DelegatesTo("physiodata")
    srvf_max_karcher_iterations = DelegatesTo("physiodata")
    srvf_update_min = DelegatesTo("physiodata")
    srvf_karcher_mean_subset_size = DelegatesTo("physiodata")
    srvf_use_moving_ensembled = DelegatesTo("physiodata")
    bspline_before_warping = DelegatesTo("physiodata")
    dzdt_num_inputs_to_group_warping = DelegatesTo("physiodata")
    srvf_t_min = DelegatesTo("physiodata")
    srvf_t_max = DelegatesTo("physiodata")
    n_modes = DelegatesTo("physiodata")

    # For saving outputs
    num_cores = Int(1)
    file_suffix = Str()
    overwrite = Bool(False)
    input_directory = Directory()
    output_directory = Directory()
    files = List(Instance(FileToProcess))

    b_run = Button(label="Run")

    def __init__(self, **traits):
        super(BatchGroupRegisterDZDT, self).__init__(**traits)
        # Default to every CPU, but respect an explicitly requested count
        if "num_cores" not in traits:
            self.num_cores = multiprocessing.cpu_count()

    def _physiodata_default(self):
        return PhysioData()

    @on_trait_change(("physiodata.srvf_max_karcher_iterations, "
                      "physiodata.srvf_use_moving_ensembled, "
                      "physiodata.srvf_karcher_mean_subset_size, "
                      "physiodata.bspline_before_warping"))
    def params_edited(self):
        """Rebuild the file list when registration parameters change."""
        self._input_directory_changed()

    def _num_cores_changed(self):
        # Out-of-range values are reset to the number of available CPUs
        max_cores = multiprocessing.cpu_count()
        if self.num_cores < 1 or self.num_cores > max_cores:
            self.num_cores = max_cores

    def _configure_io(self):
        """Rebuild ``files`` from the input directory and output settings."""
        # Without an input directory the glob below would search "/"
        if not self.input_directory:
            self.files = []
            return

        potential_files = glob(self.input_directory + "/*mea.mat")
        potential_files = [f for f in potential_files
                           if not f.endswith("_aligned.mea.mat")]

        def make_output_file(input_file):
            """Map an input path to its output path."""
            input_dir, base_name = os.path.split(os.path.abspath(input_file))
            out_name = base_name[:-len(".mea.mat")] + self.file_suffix
            # With no output directory, write next to the input rather than
            # to the filesystem root
            out_dir = self.output_directory or input_dir
            return out_dir + "/" + out_name

        self.files = [
            FileToProcess(input_file=f, output_file=make_output_file(f))
            for f in potential_files
        ]

    def _output_directory_changed(self):
        self._configure_io()

    def _file_suffix_changed(self):
        self._configure_io()

    def _input_directory_changed(self):
        self._configure_io()

    def _b_run_fired(self):
        """Register every unfinished file in a pool of ``num_cores`` workers."""
        files_to_run = [f for f in self.files if not f.finished]

        logger.info("Processing %d files using %d cpus", len(files_to_run),
                    self.num_cores)
        if not files_to_run:
            return

        arglist = [
            (f.input_file, f.output_file, self.srvf_lambda,
             self.srvf_max_karcher_iterations, self.srvf_update_min,
             self.srvf_karcher_mean_subset_size,
             self.srvf_use_moving_ensembled, self.bspline_before_warping,
             self.dzdt_num_inputs_to_group_warping, self.srvf_t_min,
             self.srvf_t_max, self.n_modes) for f in files_to_run
        ]
        pool = multiprocessing.Pool(self.num_cores)
        try:
            pool.map(process_physio_file, arglist)
        finally:
            # Always release the worker processes, even if a file fails
            pool.close()
            pool.join()

    mean_widgets = VGroup(
        Item("input_directory"), Item("output_directory"), Item("file_suffix"),
        Item("srvf_use_moving_ensembled", label="Use Moving Ensembled dZ/dt"),
        Item("bspline_before_warping", label="B Spline smoothing"),
        Item("srvf_lambda", label="Lambda Value"),
        Item("dzdt_num_inputs_to_group_warping",
             label="Template uses N beats"),
        Item("srvf_max_karcher_iterations", label="Max Iterations"),
        Item("n_modes"), Item("srvf_t_min"), Item("srvf_t_max"),
        Item("num_cores"), Item("b_run", show_label=False))

    traits_view = MEAPView(HSplit(
        Item("files", editor=files_table, show_label=False), mean_widgets),
                           resizable=True,
                           win_title="Batch Warp dZ/dt",
                           width=800,
                           height=700,
                           buttons=[OKButton, CancelButton])
Esempio n. 13
0
class SubjectInfo(HasTraits):
    """
    Dialog for the participant measurements stored on a ``PhysioData``.

    Every ``subject_*`` trait delegates to the same-named trait on
    ``physiodata``, so edits made in this view are read from and written to
    that object.
    """
    physiodata = Instance(PhysioData)
    subject_age = DelegatesTo("physiodata")
    subject_gender = DelegatesTo("physiodata")
    subject_weight = DelegatesTo("physiodata")
    subject_height_ft = DelegatesTo("physiodata")
    subject_height_in = DelegatesTo("physiodata")
    subject_electrode_distance_front = DelegatesTo("physiodata")
    subject_electrode_distance_back = DelegatesTo("physiodata")
    subject_electrode_distance_right = DelegatesTo("physiodata")
    subject_electrode_distance_left = DelegatesTo("physiodata")
    subject_resp_max = DelegatesTo("physiodata")
    subject_resp_min = DelegatesTo("physiodata")
    subject_in_mri = DelegatesTo("physiodata")
    subject_control_base_impedance = DelegatesTo("physiodata")

    # Single-column form with only an OK button (no cancel/revert).
    traits_view = MEAPView(
        VGroup(Item("subject_age"),
               Item("subject_gender"),
               Item("subject_height_ft"),
               Item("subject_height_in"),
               Item("subject_weight"),
               Item("subject_electrode_distance_front"),
               Item("subject_electrode_distance_back"),
               Item("subject_electrode_distance_left"),
               Item("subject_electrode_distance_right"),
               Item("subject_resp_max"),
               Item("subject_resp_min"),
               Item("subject_in_mri"),
               Item('subject_control_base_impedance'),
               label="Participant Measurements"),
        buttons=[OKButton],
        resizable=True,
        win_title="Subject Info",
    )
Esempio n. 14
0
class MEAPPipeline(HasTraits):
    """Top-level window that walks a single recording through MEAP.

    Each button opens one processing tool (import, subject info,
    respiration, QRS detection, waveform-point labeling, moving ensembles,
    dZ/dt registration, fMRI) on the shared ``physiodata`` object;
    ``b_save`` writes that object to ``outfile``. The boolean flags below
    gate which buttons are enabled in ``traits_view``.
    """
    physiodata = Instance(PhysioData)
    file = File
    outfile = File
    mapping_txt = File

    importer = Instance(Importer)
    importer_kwargs = Dict
    b_import = Button(label="Import data")
    b_subject_info = Button(label="Subject Info")
    b_inspect = Button(label="Inspect data")
    b_resp = Button(label="Process Resp")
    b_detect = Button(label="Detect QRS Complexes")
    b_custom_points = Button(label="Label Waveform Points")
    b_moving_ens = Button(label="Compute Moving Ensembles")
    b_register = Button(label="Register dZ/dt")
    b_fmri_tool = Button(label="Process fMRI")
    b_load_experiment = Button(label="Load Design File")
    b_save = Button(label="save .meap file")
    b_clear_mem = Button(label="Clear Memory")
    interactive = Bool(False)
    saved = Bool(False)
    # Enables the point-labeling and dZ/dt registration buttons
    peaks_detected = Bool(False)
    # Enables the moving-ensemble button
    global_ensemble_marked = Bool(False)

    def _update_state_flags(self):
        """Recompute the button-gating flags from the loaded ``physiodata``.

        Both flags are recomputed (not only raised), so loading a new file
        can never leave them reporting state from a previous file.
        """
        if self.physiodata is None:
            self.global_ensemble_marked = False
            self.peaks_detected = False
            return
        self.global_ensemble_marked = bool(
            self.physiodata.using_hand_marked_point_priors)
        peak_indices = self.physiodata.peak_indices
        self.peaks_detected = (peak_indices is not None
                               and len(peak_indices) > 0)

    def _b_clear_mem_fired(self):
        """Drop the loaded data and reset the processing-state flags."""
        self.physiodata = None
        self.global_ensemble_marked = False
        self.peaks_detected = False

    def import_data(self):
        """Load ``self.file`` into ``self.physiodata``.

        * ``.mea.mat`` files are loaded directly with ``load_from_disk``.
        * ``.acq`` and other ``.mat`` files go through a modal importer
          dialog; cancelling it leaves ``physiodata`` unchanged.
        * A missing file is reported through ``fail``; any other extension
          is silently ignored.

        After a successful load the button-gating flags are recomputed from
        the new data.
        """
        logger.info("Loading %s", self.file)

        for k in self.importer_kwargs:
            logger.info("Using attr '%s' from excel file", k)

        if not os.path.exists(self.file):
            fail("unable to open file %s" % self.file,
                 interactive=self.interactive)
            return

        elif self.file.endswith(".acq"):
            importer = AcqImporter(path=str(self.file),
                                   mapping_txt=self.mapping_txt,
                                   **self.importer_kwargs)

        # ".mea.mat" must be tested before the generic ".mat" branch
        elif self.file.endswith(".mea.mat"):
            self.physiodata = load_from_disk(self.file)
            self._update_state_flags()
            return
        elif self.file.endswith(".mat"):
            importer = MatfileImporter(path=str(self.file),
                                       **self.importer_kwargs)
        else:
            return

        ui = importer.edit_traits(kind="livemodal")
        if not ui.result:
            logger.info("Import cancelled")
            return
        self.physiodata = importer.get_physiodata()
        self._update_state_flags()

    def _b_import_fired(self):
        self.import_data()

    def _b_resp_fired(self):
        RespirationProcessor(physiodata=self.physiodata).edit_traits(
            kind="livemodal")

    def _b_subject_info_fired(self):
        SubjectInfo(physiodata=self.physiodata).edit_traits(kind="livemodal")

    def _b_custom_points_fired(self):
        """Open the global-ensemble point labeler.

        Automatic point marking runs only the first time; afterwards the
        hand-marked points are kept and just re-displayed.
        """
        ge = GlobalEnsembleAveragedHeartBeat(physiodata=self.physiodata)
        if not self.physiodata.using_hand_marked_point_priors:
            ge.mark_points()
        ge.edit_traits(kind="livemodal")
        self.physiodata.using_hand_marked_point_priors = True
        self.global_ensemble_marked = True

    def _b_inspect_fired(self):
        DataPlot(physiodata=self.physiodata).edit_traits()

    def _b_detect_fired(self):
        """Run the QRS detector dialog; mark peaks as detected on OK."""
        detector = PanTomkinsDetector(physiodata=self.physiodata)
        ui = detector.edit_traits(kind="livemodal")
        if ui.result:
            self.peaks_detected = True

    def _b_moving_ens_fired(self):
        MovingEnsembler(physiodata=self.physiodata).edit_traits()

    def _b_register_fired(self):
        GroupRegisterDZDT(physiodata=self.physiodata).edit_traits()

    def _b_save_fired(self):
        """Write ``physiodata`` to ``outfile`` (overwriting with a warning)."""
        if self.physiodata is None:
            fail("No data loaded; nothing to save",
                 interactive=self.interactive)
            return
        logger.info("writing %s", self.outfile)
        if os.path.exists(self.outfile):
            logger.warn("%s already exists", self.outfile)
        self.physiodata.save(self.outfile)
        self.saved = True
        logger.info("saved %s", self.outfile)

    def _b_fmri_tool_fired(self):
        FMRITool(physiodata=self.physiodata).edit_traits()

    traits_view = MEAPView(
        VGroup(
            VGroup(Item("file"), Item("outfile")),
            VGroup(
                spring,
                Item("b_import",
                     show_label=False,
                     enabled_when=
                     "file.endswith('.acq') or file.endswith('.mat')"),
                Item("b_inspect", enabled_when="physiodata is not None"),
                Item("b_subject_info", enabled_when="physiodata is not None"),
                Item("b_resp", enabled_when="physiodata is not None"),
                Item("b_detect", enabled_when="physiodata is not None"),
                Item("b_custom_points", enabled_when="peaks_detected"),
                Item("b_moving_ens", enabled_when="global_ensemble_marked"),
                Item("b_register", enabled_when="peaks_detected"),
                # Guard against physiodata being None (e.g. after Clear
                # Memory) so the condition never raises AttributeError.
                Item("b_fmri_tool",
                     enabled_when=
                     "physiodata is not None and "
                     "physiodata.processed_respiration_data.size > 0"),
                Item("b_save",
                     enabled_when=
                     "outfile.endswith('.mea') or outfile.endswith('.mea.mat')"
                     ),
                Item("b_clear_mem"),
                spring,
                show_labels=False)))
Esempio n. 15
0
class PanTomkinsDetector(HasTraits):
    """Interactive Pan-Tompkins style QRS (R-peak) detector.

    Builds a "QRS power" signal from the ECG (optional bandpass, squared
    first difference, optional moving-average smoothing, optional blend with
    a second ECG channel). Its peaks are thresholded with Otsu's method and
    the surviving beat indices/times are written back to ``physiodata``
    through the delegated traits. Peaks can be edited by hand in the plot
    through ``PeakPickingTool``.
    """
    # Holds the data Source
    physiodata = Instance(PhysioData)
    ecg_ts = Instance(TimeSeries)
    # True when parameters changed since the last detection run
    dirty = Bool(False)

    # For visualization
    plot = Instance(Plot, transient=True)
    plot_data = Instance(ArrayPlotData, transient=True)
    image_plot = Instance(Plot, transient=True)
    image_plot_data = Instance(ArrayPlotData, transient=True)
    image_selection_tool = Instance(ImageRangeSelector)
    image_plot_selected = DelegatesTo("image_selection_tool")
    peak_vis = Instance(ScatterPlot, transient=True)
    scatter = Instance(ScatterPlot, transient=True)
    start_time = Range(low=0, high=100000.0, initial=0.0)
    window_size = Range(low=1.0, high=100000.0, initial=30.0)
    peak_editor = Instance(PeakPickingTool, transient=True)
    show_censor_regions = Bool(True)

    # parameters for processing the raw data before PT detecting
    bandpass_min = DelegatesTo("physiodata")
    bandpass_max = DelegatesTo("physiodata")
    smoothing_window_len = DelegatesTo("physiodata")
    smoothing_window = DelegatesTo("physiodata")
    pt_adjust = DelegatesTo("physiodata")
    peak_threshold = DelegatesTo("physiodata")
    apply_filter = DelegatesTo("physiodata")
    apply_diff_sq = DelegatesTo("physiodata")
    apply_smooth_ma = DelegatesTo("physiodata")
    peak_window = DelegatesTo("physiodata")
    qrs_source_signal = DelegatesTo("physiodata")
    qrs_power_signal = Property(Array,
                                depends_on=qrs_power_signal_dependencies)

    # Use a secondary signal to limit the search range
    use_secondary_heartbeat = DelegatesTo("physiodata")
    secondary_heartbeat = DelegatesTo("physiodata")
    secondary_heartbeat_abs = DelegatesTo("physiodata")
    secondary_heartbeat_pre_msec = DelegatesTo("physiodata")
    secondary_heartbeat_window = DelegatesTo("physiodata")
    secondary_heartbeat_window_len = DelegatesTo("physiodata")
    secondary_heartbeat_n_likelihood_bins = DelegatesTo("physiodata")
    aux_signal = Property(Array)  #, depends_on=aux_signal_dependencies)

    # Secondary ECG signal
    can_use_ecg2 = Bool(False)
    use_ECG2 = DelegatesTo("physiodata")
    ecg2_weight = DelegatesTo("physiodata")

    # The results of the peak search
    peak_indices = DelegatesTo("physiodata")
    peak_times = DelegatesTo("physiodata")
    peak_values = Array
    # "dne" peaks: non-ensembleable beats, kept for HR/HRV (see plot legend)
    dne_peak_indices = DelegatesTo("physiodata")
    dne_peak_times = DelegatesTo("physiodata")
    dne_peak_values = Array
    # Holds the timeseries for the signal vs noise thresholds
    thr_times = Array
    thr_vals = Array

    # for pulling out data for aggregating ecg_ts, dzdt_ts and bp beats
    ecg_matrix = DelegatesTo("physiodata")

    # UI elements
    b_detect = Button(label="Detect QRS", transient=True)
    b_shape = Button(label="Shape-Based Tuning", transient=True)

    def __init__(self, **traits):
        """Pick the primary/secondary ECG channels and load existing peaks.

        ``qrs_source_signal`` selects which channel ("ecg" or the second
        ECG) drives detection; the other, if present, becomes
        ``ecg2_signal``. Peaks already stored in ``physiodata`` are shown.
        """
        super(PanTomkinsDetector, self).__init__(**traits)
        self.ecg_ts = TimeSeries(physiodata=self.physiodata, contains="ecg")
        # relevant data to be plotted
        self.ecg_time = self.ecg_ts.time

        self.can_use_ecg2 = "ecg2" in self.physiodata.contents

        # Which ECG signal to use?
        if self.qrs_source_signal == "ecg" or not self.can_use_ecg2:
            self.ecg_signal = normalize(self.ecg_ts.data)
            if self.can_use_ecg2:
                self.ecg2_signal = normalize(self.physiodata.ecg2_data)
            else:
                self.ecg2_signal = None
        else:
            self.ecg_signal = normalize(self.physiodata.ecg2_data)
            self.ecg2_signal = normalize(self.physiodata.ecg_data)

        self.thr_times = self.ecg_time[np.array([0, -1])]
        self.censored_intervals = self.physiodata.censored_intervals
        # graphics containers
        self.aux_window_graphics = []
        self.censored_region_graphics = []
        # If peaks are already available in the mea.mat,
        # they have already been marked
        if len(self.peak_times):
            self.dirty = False
            self.peak_values = self.ecg_signal[self.peak_indices]
        if self.dne_peak_indices is not None and len(self.dne_peak_indices):
            self.dne_peak_values = self.ecg_signal[self.dne_peak_indices]

    def _b_shape_fired(self):
        self.shape_based_tuning()

    def shape_based_tuning(self):
        """Stack the QRS power signal around each detected peak.

        NOTE(review): the stack is computed but not used yet -- this
        feature appears unfinished.
        """
        logger.info("Running shape-based tuning")
        qrs_stack = peak_stack(self.peak_indices,
                               self.qrs_power_signal,
                               pre_msec=300,
                               post_msec=700,
                               sampling_rate=self.physiodata.ecg_sampling_rate)

    def _b_detect_fired(self):
        self.detect()

    def _show_censor_regions_changed(self):
        """Toggle visibility of the censored-interval overlays."""
        for crg in self.censored_region_graphics:
            crg.visible = self.show_censor_regions
        self.plot.request_redraw()

    @cached_property
    def _get_qrs_power_signal(self):
        """Return the normalized QRS power signal used for detection.

        Pipeline per channel: optional bandpass -> optional squared first
        difference -> optional moving-average smoothing. With ``use_ECG2``
        the two channels are blended by ``ecg2_weight`` and squared.
        """
        if self.apply_filter:
            filtered_ecg = normalize(
                bandpass(self.ecg_signal, self.bandpass_min, self.bandpass_max,
                         self.ecg_ts.sampling_rate))
        else:
            filtered_ecg = self.ecg_signal

        # Differentiate and square the signal?
        if self.apply_diff_sq:
            # Differentiate the signal and square it
            diff_sq = np.ediff1d(filtered_ecg, to_begin=0)**2
        else:
            diff_sq = filtered_ecg

        # If we're to apply a Moving Average smoothing
        if self.apply_smooth_ma:
            # MA smoothing
            smooth_ma = smooth(diff_sq,
                               window_len=self.smoothing_window_len,
                               window=self.smoothing_window)
        else:
            smooth_ma = diff_sq
        if not self.use_ECG2:
            # for visualization purposes
            return normalize(smooth_ma)
        logger.info("Using 2nd ECG signal")

        # Use secondary ECG signal and combine QRS power
        if self.apply_filter:
            filtered_ecg2 = normalize(
                bandpass(self.ecg2_signal, self.bandpass_min,
                         self.bandpass_max, self.ecg_ts.sampling_rate))
        else:
            filtered_ecg2 = self.ecg2_signal
        if self.apply_diff_sq:
            diff_sq2 = np.ediff1d(filtered_ecg2, to_begin=0)**2
        else:
            diff_sq2 = filtered_ecg2

        if self.apply_smooth_ma:
            smooth_ma2 = smooth(diff_sq2,
                                window_len=self.smoothing_window_len,
                                window=self.smoothing_window)
        else:
            smooth_ma2 = diff_sq2

        return normalize(((1 - self.ecg2_weight) * smooth_ma +
                          self.ecg2_weight * smooth_ma2)**2)

    def _get_aux_signal(self):
        """Return the normalized secondary heartbeat signal (or empty)."""
        if not self.use_secondary_heartbeat: return np.array([])
        sig = getattr(self.physiodata, self.secondary_heartbeat + "_data")
        if self.secondary_heartbeat_abs:
            sig = np.abs(sig)
        if self.secondary_heartbeat_window_len > 0:
            sig = smooth(sig,
                         window_len=self.secondary_heartbeat_window_len,
                         window=self.secondary_heartbeat_window)
        return normalize(sig)

    @on_trait_change(algorithm_parameters)
    def params_edited(self):
        self.dirty = True

    def update_aux_window_graphics(self):
        """Redraw the secondary-signal search windows over the ECG trace.

        Uses ``self.aux_windows`` set by ``detect``; window bounds are
        divided by 1000 to get seconds (assumes 1 sample per ms -- TODO
        confirm against the sampling rate).
        """
        if not hasattr(self, "ecg_trace") or not self.use_secondary_heartbeat:
            return
        # remove the old graphics
        for aux_window in self.aux_window_graphics:
            self.ecg_trace.index.metadata.pop(aux_window.metadata_name, None)
            self.ecg_trace.overlays.remove(aux_window)
        # Forget the removed overlays; otherwise the next call would try to
        # remove them (and their metadata keys) a second time and raise.
        self.aux_window_graphics = []
        # add the new graphics
        for n, (start, end) in enumerate(self.aux_windows / 1000.):
            window_key = "aux%03d" % n
            self.aux_window_graphics.append(
                AuxSignalWindow(component=self.aux_trace,
                                metadata_name=window_key))
            self.ecg_trace.overlays.append(self.aux_window_graphics[-1])
            self.ecg_trace.index.metadata[window_key] = start, end

    def get_aux_windows(self):
        """Return (start, peak) sample windows preceding each aux peak."""
        aux_peaks = find_peaks(self.aux_signal)
        aux_pre_peak = aux_peaks - self.secondary_heartbeat_pre_msec
        aux_windows = np.column_stack([aux_pre_peak, aux_peaks])
        return aux_windows

    def detect(self):
        """
        Implementation of the Pan Tomkins QRS detector

        Candidate peaks of the QRS power signal are pruned by censored
        intervals, optional secondary-signal windows, an Otsu threshold
        shifted by ``pt_adjust`` and a minimum spacing, then snapped to the
        raw-ECG maximum within ``peak_window``. Results go to
        ``peak_indices``/``peak_times``/``peak_values``; plots are refreshed
        when they exist.
        """
        logger.info("Beginning QRS Detection")
        t0 = time.time()

        # The original paper used a different method for finding peaks
        smoothdiff = normalize(np.ediff1d(self.qrs_power_signal, to_begin=0))
        peaks = find_peaks(smoothdiff)
        peak_amps = self.qrs_power_signal[peaks]

        # Part 2: getting rid of useless peaks
        # ====================================
        # There are lots of small, irrelevant peaks that need to
        # be discarded.
        # 2a) remove peaks occurring in censored intervals
        # TODO: switch to physiodata.censored_intervals
        censor_peak_mask = censor_peak_times(self.ecg_ts.censored_regions,
                                             self.ecg_time[peaks])
        n_removed_peaks = peaks.shape[0] - censor_peak_mask.sum()
        logger.info("%d/%d potential peaks outside %d censored intervals",
                    n_removed_peaks, peaks.shape[0],
                    len(self.ecg_ts.censored_regions))
        peaks = peaks[censor_peak_mask]
        peak_amps = peak_amps[censor_peak_mask]

        # 2b) if a second signal is used, make sure the ecg peaks are
        #     near aux signal peaks
        if self.use_secondary_heartbeat:
            self.aux_windows = self.get_aux_windows()
            aux_mask = times_contained_in(peaks, self.aux_windows)
            logger.info("Using secondary signal")
            logger.info("%d aux peaks detected", self.aux_windows.shape[0])
            logger.info("%d/%d peaks contained in second signal window",
                        aux_mask.sum(), peaks.shape[0])
            peaks = peaks[aux_mask]
            peak_amps = peak_amps[aux_mask]

        # 2c) use otsu's method to find a cutoff value
        peak_amp_thr = threshold_otsu(peak_amps) + self.pt_adjust
        otsu_mask = peak_amps > peak_amp_thr
        logger.info("otsu threshold: %.7f", peak_amp_thr)
        logger.info("%d/%d peaks survive", otsu_mask.sum(), peaks.shape[0])
        peaks = peaks[otsu_mask]
        peak_amps = peak_amps[otsu_mask]

        # 3) Make sure there is only one peak in each secondary window
        #    this is accomplished by estimating the distribution of peak
        #    times relative to aux window starts
        if self.use_secondary_heartbeat:
            npeaks = peaks.shape[0]
            # obtain a distribution of times and amplitudes
            rel_peak_times = np.zeros_like(peaks)
            window_subsets = []
            for start, end in self.aux_windows:
                mask = np.flatnonzero((peaks >= start) & (peaks <= end))
                if len(mask) == 0: continue
                window_subsets.append(mask)
                rel_peak_times[mask] = peaks[mask] - start

            # how common are relative peak times relative to window starts
            densities, bins = np.histogram(
                rel_peak_times,
                range=(0, self.secondary_heartbeat_pre_msec),
                bins=self.secondary_heartbeat_n_likelihood_bins,
                density=True)

            likelihoods = densities[np.clip(
                np.digitize(rel_peak_times, bins) - 1, 0,
                self.secondary_heartbeat_n_likelihood_bins - 1)]
            _peaks = []
            # Pull out the maximal peak
            for subset in window_subsets:
                # If there's only a single peak contained, no need for math
                if len(subset) == 1:
                    _peaks.append(peaks[subset[0]])
                    continue

                _peaks.append(peaks[subset[np.argmax(likelihoods[subset])]])
            # Keep the integer dtype even when no window held a peak, so the
            # result remains usable as an index array.
            peaks = np.array(_peaks, dtype=peaks.dtype)
            logger.info("Only 1 peak per aux window allowed:"
                        " %d/%d peaks remaining", peaks.shape[0], npeaks)

        # Check that no two peaks are too close:
        peak_diffs = np.ediff1d(peaks, to_begin=500)
        peaks = peaks[peak_diffs > 200]  # Cutoff is 300BPM (assumes 1 kHz)

        # Stack the peaks and see if the original data has a higher value.
        # The argmax offset minus peak_window assumes 1 sample per ms.
        raw_stack = peak_stack(peaks,
                               self.ecg_signal,
                               pre_msec=self.peak_window,
                               post_msec=self.peak_window,
                               sampling_rate=self.ecg_ts.sampling_rate)
        adj_factors = np.argmax(raw_stack, axis=1) - self.peak_window
        peaks = peaks + adj_factors

        self.peak_indices = peaks
        self.peak_values = self.ecg_signal[peaks]
        self.peak_times = self.ecg_time[peaks]
        self.thr_vals = np.array([peak_amp_thr] * 2)
        t1 = time.time()
        # update the scatterplot if we're interactive
        if self.plot_data is not None:
            self.plot_data.set_data("peak_times", self.peak_times)
            self.plot_data.set_data("peak_values", self.peak_values)
            self.plot_data.set_data("qrs_power", self.qrs_power_signal)
            self.plot_data.set_data("aux_signal", self.aux_signal)
            self.plot_data.set_data("thr_vals", self.thr_vals)
            # The threshold line spans the whole recording
            self.plot_data.set_data(
                "thr_times", np.array([self.ecg_time[0], self.ecg_time[-1]]))
            self.update_aux_window_graphics()
            self.plot.request_redraw()
            self.image_plot_data.set_data("imagedata", self.ecg_matrix)
            self.image_plot.request_redraw()
        else:
            logger.debug("plot data is none")
        logger.info("found %d QRS complexes in %.3f seconds",
                    len(self.peak_indices), t1 - t0)
        self.dirty = False

    def _plot_default(self):
        """
        Creates a plot of the ecg_ts data and the signals derived during
        the Pan Tomkins algorithm
        """
        # Create plotting components
        self.plot_data = ArrayPlotData(
            time=self.ecg_time,
            raw_ecg=self.ecg_signal,
            qrs_power=self.qrs_power_signal,
            aux_signal=self.aux_signal,
            # Ensembleable peaks
            peak_times=self.peak_times,
            peak_values=self.peak_values,
            # Non-ensembleable, useful for HR, HRV
            dne_peak_times=self.dne_peak_times,
            dne_peak_values=self.dne_peak_values,
            # Threshold slider bar
            thr_times=np.array([self.ecg_time[0], self.ecg_time[-1]]),
            thr_vals=np.array([0, 0]))

        plot = Plot(self.plot_data, use_backbuffer=True)
        self.aux_trace = plot.plot(("time", "aux_signal"),
                                   color="purple",
                                   alpha=0.6,
                                   line_width=2)[0]
        self.qrs_power_trace = plot.plot(("time", "qrs_power"),
                                         color="blue",
                                         alpha=0.8)[0]
        self.ecg_trace = plot.plot(("time", "raw_ecg"),
                                   color="red",
                                   line_width=2)[0]

        # Load the censor regions and add them to the plot
        for n, (start, end) in enumerate(self.censored_intervals):
            censor_key = "censor%03d" % n
            self.censored_region_graphics.append(
                StaticCensoredInterval(component=self.qrs_power_trace,
                                       metadata_name=censor_key))
            self.ecg_trace.overlays.append(self.censored_region_graphics[-1])
            self.ecg_trace.index.metadata[censor_key] = start, end

        # Line showing the detection threshold
        self.threshold_trace = plot.plot(("thr_times", "thr_vals"),
                                         color="black",
                                         line_width=3)[0]

        # Plot for plausible peaks
        self.ppeak_vis = plot.plot(("dne_peak_times", "dne_peak_values"),
                                   type="scatter",
                                   marker="diamond")[0]

        # Make a scatter plot where you can edit the peaks
        self.peak_vis = plot.plot(("peak_times", "peak_values"),
                                  type="scatter")[0]
        self.scatter = plot.components[-1]
        self.peak_editor = PeakPickingTool(self.scatter)
        self.scatter.tools.append(self.peak_editor)
        # when the user adds or removes a point, automatically extract
        self.on_trait_event(self._change_peaks, "peak_editor.done_selecting")

        self.scatter.overlays.append(
            PeakPickingOverlay(component=self.scatter))

        return plot

    # TODO: have this update other variables in physiodata: hand_labeled, mea, etc
    def _delete_peaks(self, peaks_to_delete):
        pass

    def _add_peaks(self, peaks_to_add):
        pass

    # End TODO

    def _change_peaks(self):
        """Apply the PeakPickingTool selection to the peak lists.

        "delete" removes regular and dne peaks inside the selected time
        interval; "add" inserts the ECG maximum in the interval as a regular
        peak; any other mode inserts it as a dne peak.
        """
        if self.dirty:
            messagebox("You shouldn't edit peaks until you've run Detect QRS")
        interval = self.peak_editor.selection
        if interval is None: return
        mode = self.peak_editor.selection_purpose
        logger.info("PeakPickingTool entered %s mode over %s", mode,
                    str(interval))
        if mode == "delete":
            # Do it for real peaks
            ok_peaks = np.logical_not(
                np.logical_and(self.peak_times > interval[0],
                               self.peak_times < interval[1]))
            self.peak_indices = self.peak_indices[ok_peaks]
            self.peak_values = self.peak_values[ok_peaks]
            self.peak_times = self.peak_times[ok_peaks]
            # And do it for the non-ensembleable peaks, if there are any
            if self.dne_peak_indices is not None:
                ok_dne_peaks = np.logical_not(
                    np.logical_and(self.dne_peak_times > interval[0],
                                   self.dne_peak_times < interval[1]))
                self.dne_peak_indices = self.dne_peak_indices[ok_dne_peaks]
                self.dne_peak_values = self.dne_peak_values[ok_dne_peaks]
                self.dne_peak_times = self.dne_peak_times[ok_dne_peaks]
        else:
            # Find the signal contained in the selection
            sig = np.logical_and(self.ecg_time > interval[0],
                                 self.ecg_time < interval[1])
            sig_inds = np.flatnonzero(sig)
            if sig_inds.size == 0:
                # Selection narrower than one sample: nothing to add
                logger.info("Selection contains no samples; ignoring")
                return
            selected_sig = self.ecg_signal[sig_inds]
            # find the peak in the selected region
            peak_ind = sig_inds[0] + np.argmax(selected_sig)

            # If we're in add peak mode, always make sure
            # that only unique peaks get added:
            if mode == "add":
                real_peaks = np.sort(
                    np.unique(self.peak_indices.tolist() + [peak_ind]))
                self.peak_indices = real_peaks
                self.peak_values = self.ecg_signal[real_peaks]
                self.peak_times = self.ecg_time[real_peaks]
            else:
                existing_dne = ([] if self.dne_peak_indices is None else
                                self.dne_peak_indices.tolist())
                real_dne_peaks = np.sort(np.unique(existing_dne + [peak_ind]))
                self.dne_peak_indices = real_dne_peaks
                self.dne_peak_values = self.ecg_signal[real_dne_peaks]
                self.dne_peak_times = self.ecg_time[real_dne_peaks]

        # update the scatterplot if we're interactive
        if self.plot_data is not None:
            self.plot_data.set_data("peak_times", self.peak_times)
            self.plot_data.set_data("peak_values", self.peak_values)
            self.plot_data.set_data("dne_peak_times", self.dne_peak_times)
            self.plot_data.set_data("dne_peak_values", self.dne_peak_values)
            self.image_plot_data.set_data("imagedata", self.ecg_matrix)
            self.image_plot.request_redraw()

    @on_trait_change("window_size,start_time")
    def update_plot_range(self):
        self.plot.index_range.high = self.start_time + self.window_size
        self.plot.index_range.low = self.start_time

    @on_trait_change("image_plot_selected")
    def snap_to_image_selection(self):
        """Zoom the ECG plot to the beats selected in the image plot."""
        if not self.peak_times.size > 0: return
        # Selector rows are beat numbers; clamp them to valid beat indices
        # so a selection touching the image edge cannot overrun the array.
        last_beat = self.peak_times.size - 1
        lo_beat = min(max(int(self.image_selection_tool.ymin), 0), last_beat)
        hi_beat = min(max(int(self.image_selection_tool.ymax), 0), last_beat)
        tmin = self.peak_times[lo_beat] - 2.
        tmax = self.peak_times[hi_beat] + 2.
        logger.info("selection tool sends data to (%.2f, %.2f)", tmin, tmax)
        self.plot.index_range.low = tmin - 2.
        self.plot.index_range.high = tmax + 2.

    def _image_plot_default(self):
        """Create the beat-by-beat ECG image with a range selector."""
        # for image plot
        img = self.ecg_matrix
        if self.ecg_matrix.size == 0:
            img = np.zeros((100, 100))
        self.image_plot_data = ArrayPlotData(imagedata=img)
        plot = Plot(self.image_plot_data)
        self.image_selection_tool = ImageRangeSelector(component=plot,
                                                       tool_mode="range",
                                                       axis="value",
                                                       always_on=True)
        plot.img_plot("imagedata",
                      colormap=jet,
                      name="plot1",
                      origin="bottom left")
        plot.overlays.append(self.image_selection_tool)
        return plot

    detection_params_group = VGroup(
        HGroup(
            Group(
                VGroup(Item("use_secondary_heartbeat"),
                       Item("peak_window"),
                       Item("apply_filter"),
                       Item("bandpass_min"),
                       Item("bandpass_max"),
                       Item("smoothing_window_len"),
                       Item("smoothing_window"),
                       Item("pt_adjust"),
                       Item("apply_diff_sq"),
                       label="Pan Tomkins"),
                VGroup(  # Items for MultiSignal detection
                    Item("secondary_heartbeat"),
                    Item("secondary_heartbeat_pre_msec"),
                    Item("secondary_heartbeat_abs"),
                    Item("secondary_heartbeat_window"),
                    Item("secondary_heartbeat_window_len"),
                    label="MultiSignal Detection",
                    enabled_when="use_secondary_heartbeat"),
                VGroup(  # Items for two ECG Signals
                    Item("use_ECG2"),
                    Item("ecg2_weight"),
                    label="Multiple ECG",
                    enabled_when="can_use_ecg2"),
                label="Beat Detection Options",
                layout="tabbed",
                show_border=True,
                springy=True),
            Item("image_plot", editor=ComponentEditor(), width=300),
            show_labels=False),
        Item("b_detect", enabled_when="dirty"),
        Item("b_shape"),
        show_labels=False)

    plot_group = Group(
        Group(Item("plot", editor=ComponentEditor(), width=800),
              show_labels=False), Item("start_time"), Item("window_size"))
    traits_view = MEAPView(
        VSplit(
            plot_group,
            detection_params_group,
            #orientation="vertical"
        ),
        resizable=True,
        buttons=[OKButton, CancelButton],
        title="Detect Heart Beats")
Esempio n. 16
0
class MEAPTimeseries(HasTraits):
    """Beat-wise plot of one MEAP signal with interactive range selection.

    ``plotdata`` must provide ``peak_times``, ``beat_type`` and the column
    named by ``signal`` (plus ``mea_hr`` when ``signal == "hr"`` and,
    optionally, ``resp_corrected_<signal>`` for tpr/co/sv). When the user
    completes a range selection, ``selected_range`` is updated and the
    ``selected`` event fires.
    """
    physiodata = Instance(PhysioData)
    plot = Instance(Plot, transient=True)
    plotdata = Instance(ArrayPlotData, transient=True)
    selection = Instance(BeatSelection, transient=True)
    signal = Str
    selected_beats = Array
    index_datasourse = Instance(DataRange1D)
    marker_size = Array
    metadata_name = Str
    selected = Event
    selected_range = Tuple

    traits_view = MEAPView(Group(Item('plot',
                                      editor=ComponentEditor(size=size),
                                      show_label=False),
                                 orientation="vertical"),
                           resizable=True,
                           win_title=title)

    def __init__(self, **traits):
        super(MEAPTimeseries, self).__init__(**traits)
        self.metadata_name = self.signal + "_selection"

    def _selection_changed(self):
        """Publish a completed, non-empty selection via ``selected``.

        Also runs as the static handler for the ``selection`` trait itself;
        at that point there is no selection yet, so it returns early.
        """
        selection = self.selection.selection
        logger.info("%s selection changed to %s", self.signal, str(selection))
        if selection is None or len(selection) == 0:
            return
        self.selected_range = selection
        self.selected = True

    def _choose_plotted_signal(self, plot):
        """Return the plotdata key to draw in the foreground.

        For tpr/co/sv the respiration-corrected series is preferred when it
        is available and non-empty; the uncorrected series is then drawn
        underneath in purple. For hr, the raw series is drawn underneath and
        ``mea_hr`` is used in the foreground.
        """
        if self.signal in ("tpr", "co", "sv"):
            rc_signal = "resp_corrected_" + self.signal
            if (rc_signal in self.plotdata.arrays
                    and self.plotdata.arrays[rc_signal].size > 0):
                # Plot the resp-uncorrected version underneath
                plot.plot(("peak_times", self.signal),
                          type="line",
                          color="purple")
                return rc_signal
            return self.signal
        if self.signal == "hr":
            plot.plot(("peak_times", self.signal), type="line", color="purple")
            return "mea_hr"
        return self.signal

    def _plot_default(self):
        """Build the line + beat-type scatter plot and attach selection."""
        plot = Plot(self.plotdata)

        signal = self._choose_plotted_signal(plot)

        # Create the plot
        plot.plot(("peak_times", signal), type="line", color="blue")
        plot.plot(("peak_times", signal, "beat_type"),
                  type="cmap_scatter",
                  color_mapper=jet,
                  name="my_plot",
                  marker="circle",
                  border_visible=False,
                  outline_color="transparent",
                  bg_color="white",
                  index_sort="ascending",
                  marker_size=3,
                  fill_alpha=0.8)

        # Tweak some of the plot properties
        plot.title = self.signal
        plot.title_position = "right"
        plot.title_angle = 270
        plot.line_width = 1
        plot.padding = 20

        # Right now, some of the tools are a little invasive, and we need the
        # actual ScatterPlot object to give to them
        my_plot = plot.plots["my_plot"][0]

        # Attach some tools to the plot
        self.selection = BeatSelection(component=my_plot)
        my_plot.active_tool = self.selection
        selection_overlay = RangeSelectionOverlay(component=my_plot)
        my_plot.overlays.append(selection_overlay)

        # Set up the trait handler for the selection
        self.selection.on_trait_change(self._selection_changed,
                                       'selection_completed')

        return plot
Esempio n. 17
0
class FMRITool(HasTraits):
    """Settings and slice-timing handling for RETROICOR correction of fMRI data.

    The GUI (``traits_view``, titled "RETROICOR") collects the fMRI input and
    output files, the slice-timing file and the expansion / drift-model
    orders. Respiration data are delegated to ``physiodata``.
    """
    physiodata = Instance(PhysioData)
    processed_respiration_data = DelegatesTo("physiodata")
    processed_respiration_time = DelegatesTo("physiodata")

    # Slice timing: path to a text file and the arrays derived from it
    slice_times_txt = Str
    slice_times = Array
    slice_times_matrix = Array
    acquisition_type = Enum("Coronal", "Saggittal", "Axial")
    dicom_acquisition_direction = Enum("LPS+", "RAS+")
    unsure_about_direction = Bool(False)
    # The allowed values of correction_type come from possible_corrections
    correction_type = Enum(values="possible_corrections")
    possible_corrections = List(["Slicewise Regression"])
    respiration_expansion_order = Int(4)
    drift_model_order = Int(6)
    add_back_intensity = Bool(True)
    cardiac_expansion_order = Int(3)
    interaction_expansion_order = Int(1)
    physio_design_file = File

    # Local arrays that have been restricted to during-scanning times
    r_times = Array  # Beat times occurring during scanning
    tr_onsets = Array  # TR times
    resp_times = Array
    resp_signal = Array

    # Nifti Stuff
    fmri_file = File
    fmri_mask = File
    denoised_output = File
    regression_output = File
    nuisance_output_prefix = Str
    b_calculate = Button(label="RUN")
    write_4d_nuisance = Bool(False)
    write_4d_nuisance_for_each_modality = Bool(False)
    output_units = Enum("Percent Signal Change", "Raw BOLD")
    fmri_tr = Float

    interactive = Bool(False)
    missing_triggers = Int(0)
    missing_volumes = Int(0)

    traits_view = MEAPView(VGroup(
        VGroup(Item("fmri_file"),
               Item("fmri_mask"),
               Item("slice_times_txt"),
               Item("acquisition_type"),
               Item("dicom_acquisition_direction"),
               Item("unsure_about_direction"),
               Item("denoised_output"),
               Item("regression_output"),
               Item("nuisance_output_prefix"),
               Item("write_4d_nuisance"),
               Item("write_4d_nuisance_for_each_modality"),
               show_border=True,
               label="fMRI I/O"),
        VGroup(Item("cardiac_expansion_order"),
               Item("respiration_expansion_order"),
               Item("interaction_expansion_order"),
               Item("correction_type"),
               Item("drift_model_order"),
               Item("add_back_intensity"),
               Item("output_units"),
               Item("physio_design_file"),
               show_border=True,
               label="RETROICOR")),
                           Item("b_calculate", show_label=False),
                           win_title="RETROICOR")

    def _load_slice_timings(self):
        """Load slice acquisition times from ``slice_times_txt`` into ``slice_times``.

        The file is read with ``np.loadtxt`` and normalised to seconds:

        * If any value is negative, the file is treated as using the mid-TR
          convention. Every value is shifted by 0.5 and multiplied by
          ``fmri_tr``, so the values are presumably fractions of a TR.
        * If any value is greater than 10, the times are assumed to be in
          milliseconds and are divided by 1000.

        Raises:
            ValueError: if the file cannot be loaded (the user is also shown
                a message box), or if mid-TR timings are used while
                ``fmri_tr`` is not yet set (``<= 0``).
        """
        try:
            _slices = np.loadtxt(self.slice_times_txt)
        except Exception as e:
            # GUI boundary: inform the user, then report the failure to the caller
            messagebox("Unable to load slice timings:\n%s" % e)
            raise ValueError("Unable to load slice timings: %s" % e)

        # Check if it's in the mid-TR format
        if np.any(_slices < 0):
            logger.info("Using the mid-TR slice convention")
            if self.fmri_tr <= 0:
                raise ValueError("fMRI file has to be set before using mid-TR")
            _slices = (_slices + 0.5) * self.fmri_tr

        # Check if the times are in msec or sec. Compare element-wise
        # *before* reducing: ``np.any(x) > 10`` compared a bool with 10 and
        # was always False, so msec input was never converted.
        if np.any(_slices > 10):
            logger.info("Converting slice times from msec to sec")
            _slices = _slices / 1000.
        else:
            logger.info("Assuming slice timings are in seconds")

        # Slice times should be one row per TR, and one column
        # per slice.
        self.slice_times = _slices
        logger.info("Using slice timings of %s", str(self.slice_times))
Esempio n. 18
0
class TimeSeries(HasTraits):
    """One signal loaded from a ``physiodata`` object, with a plot and editing tools.

    The signal is selected by ``contains``. The widget lets the user add or
    clear censored regions and toggle winsorization. Winsorized data and
    winsorization settings are written back onto ``physiodata``.
    """
    # Holds data
    name = Str
    contains = Enum(SUPPORTED_SIGNALS)
    plot = Instance(Plot, transient=True)
    data = Array
    # Sample times, recomputed from start_time, sampling_rate and data
    time = Property(Array, depends_on=["start_time", "sampling_rate", "data"])

    # Plotting options
    visible = Bool(True)
    ymax = Float()
    ymin = Float()
    line_type = marker_trait
    line_color = ColorTrait("blue")
    plot_type = Enum("line", "scatter")
    censored_regions = List(Instance(CensorRegion))
    b_add_censor = Button(label="Add CensorRegion", transient=True)
    b_zoom_y = Button(label="Zoom y", transient=True)
    b_info = Button(label="Info", transient=True)
    b_clear_censoring = Button(label="Clear Censoring", transient=True)
    renderer = Instance(LinePlot, transient=True)

    # For the winsorizing steps.
    # winsor_swap holds an unmodified copy of the data so winsorizing can be undone.
    winsor_swap = Array
    # (lower, upper) limits passed to winsorize()
    winsor_min = Float(0.005)
    winsor_max = Float(0.005)
    winsorize = Bool(False)
    # Set to False in __init__ when physiodata reports the signal is already winsorized
    winsor_enable = Bool(True)

    def __init__(self, **traits):
        """Load the signal named by ``contains`` from ``physiodata``.

        ``physiodata`` and ``contains`` are expected among ``traits``. For a
        signal ``s``, this reads ``s_data``, ``s_sampling_rate``,
        ``s_sampling_rate_unit``, ``s_start_time``, ``s_winsorize``,
        ``s_winsor_min`` and ``s_winsor_max`` from ``physiodata``. It also
        reads the ``censored_intervals`` whose ``censoring_sources`` entry
        equals ``s``.

        Raises:
            ValueError: if ``contains`` is not in ``physiodata.contents``.
        """
        super(TimeSeries, self).__init__(**traits)
        if self.contains not in self.physiodata.contents:
            raise ValueError("Signal not found in data")
        self.name = self.contains

        # First, check whether it's already winsorized. If so, the winsorize
        # controls are disabled (see enabled_when in the view).
        winsorize_trait = self.name + "_winsorize"
        if getattr(self.physiodata, winsorize_trait):
            self.winsor_enable = False
        self.winsor_min = getattr(self.physiodata, self.name + "_winsor_min")
        self.winsor_max = getattr(self.physiodata, self.name + "_winsor_max")

        # Load the actual data and keep a pristine copy to restore from
        self.data = getattr(self.physiodata, self.contains + "_data")
        self.winsor_swap = self.data.copy()

        self.sampling_rate = getattr(self.physiodata,
                                     self.contains + "_sampling_rate")
        self.sampling_rate_unit = getattr(
            self.physiodata, self.contains + "_sampling_rate_unit")
        self.start_time = getattr(self.physiodata,
                                  self.contains + "_start_time")

        self.line_color = colors[self.contains]

        # Censored regions are copied from physiodata only here, at
        # construction. After that, additions and clearing act on
        # self.censored_regions.
        self.n_censor_intervals = 0
        for (start, end), source in zip(self.physiodata.censored_intervals,
                                        self.physiodata.censoring_sources):
            if str(source) == self.contains:
                self.censored_regions.append(
                    CensorRegion(start_time=start,
                                 end_time=end,
                                 metadata_name=self.__get_metadata_name()))

    def __get_metadata_name(self):
        """Return the next censor-region name (``<signal>NNN``) and bump the counter."""
        name = self.contains + "%03d" % self.n_censor_intervals
        self.n_censor_intervals += 1
        return name

    def _winsorize_changed(self):
        """Apply or undo winsorization and push the result to ``physiodata``.

        Winsorizing always starts from ``winsor_swap`` (the unmodified data),
        so switching it off restores the original values. Afterwards the
        plot's y-range is reset to the new data range.
        """
        if self.winsorize:
            # Separator added between the two limits; the old format string
            # fused them into one number in the log.
            logger.info("Winsorizing %s with limits=(%.5f, %.5f)", self.name,
                        self.winsor_min, self.winsor_max)
            # Take the original data and replace it with the winsorized version
            self.data = np.array(
                winsorize(self.winsor_swap,
                          limits=(self.winsor_min, self.winsor_max)))
            setattr(self.physiodata, self.name + "_winsor_min",
                    self.winsor_min)
            setattr(self.physiodata, self.name + "_winsor_max",
                    self.winsor_max)

        else:
            logger.info("Restoring %s to its original data", self.name)
            self.data = self.winsor_swap.copy()
        setattr(self.physiodata, self.contains + "_data", self.data)
        setattr(self.physiodata, self.name + "_winsorize", self.winsorize)
        self.plot.range2d.y_range.low = self.data.min()
        self.plot.range2d.y_range.high = self.data.max()
        self.plot.request_redraw()

    def __str__(self):
        """Multi-line summary: sampling rate, sample count, censoring, start time."""
        descr = "Timeseries: %s\n" % self.name
        descr += "-" * len(descr) + "\n\t" + \
               "\n\t".join([
                   "Sampling rate: %.4f%s" % (self.sampling_rate,self.sampling_rate_unit),
                   "N samples: %d" % self.data.shape[0],
                   "N censored intervals: %d" % len(self.censored_regions),
                   "Start time: %.3f" % self.start_time
                   ])
        return descr

    def _plot_default(self):
        """Build the line plot for this signal and attach existing censor regions."""
        # Create plotting components
        plotdata = ArrayPlotData(time=self.time, data=self.data)
        plot = Plot(plotdata)
        self.renderer = plot.plot(("time", "data"), color=self.line_color)[0]
        plot.title = self.name
        plot.title_position = "right"
        plot.title_angle = 270
        plot.line_width = 1
        plot.padding = 25
        plot.width = 400

        # Load the censor regions and add them to the plot
        for censor_region in self.censored_regions:
            # Make a censor region on the Timeseries object
            censor_region.plot = self.renderer
            # NOTE(review): bare attribute access, presumably to force lazy
            # creation of the region's overlay now that plot is set — confirm.
            censor_region.viz
            censor_region.set_limits(censor_region.start_time,
                                     censor_region.end_time)
        return plot

    def _get_time(self):
        """Return sample times: index / sampling_rate, offset by start_time."""
        # Ensure the properties that depend on time are updated
        return np.arange(len(self.data)) / self.sampling_rate + self.start_time

    def _b_info_fired(self):
        """Show the ``__str__`` summary in a message box."""
        messagebox(str(self))

    def _b_zoom_y_fired(self):
        """Ask for ymin/ymax in a modal dialog and apply them if accepted."""
        ui = self.edit_traits(view="zoom_view", kind="modal")
        if ui.result:
            self.plot.range2d.y_range.low = self.ymin
            self.plot.range2d.y_range.high = self.ymax

    def _b_clear_censoring_fired(self):
        """Remove every censor region's overlay and reset the region list and counter."""
        for reg in self.censored_regions:
            self.renderer.overlays.remove(reg.viz)
        self.censored_regions = []
        self.plot.request_redraw()
        self.n_censor_intervals = 0

    def _b_add_censor_fired(self):
        """Append a new censor region attached to this signal's renderer."""
        self.censored_regions.append(
            CensorRegion(plot=self.renderer,
                         metadata_name=self.__get_metadata_name()))
        # NOTE(review): touching viz presumably creates the region's overlay — confirm
        self.censored_regions[-1].viz

    buttons = VGroup(
        Item("b_add_censor", show_label=False),
        #Item("b_info"),
        #Item("b_zoom_y"),
        Item("winsorize", enabled_when="winsor_enable"),
        Item("winsor_max",
             enabled_when="winsor_enable",
             format_str="%.4f",
             width=25),
        Item("winsor_min",
             enabled_when="winsor_enable",
             format_str="%.4f",
             width=25),
        Item("b_clear_censoring", show_label=False),
        show_labels=True,
        show_border=True)
    widgets = HSplit(Item('plot', editor=ComponentEditor(), show_label=False),
                     buttons)

    zoom_view = MEAPView(HGroup("ymin", "ymax"),
                         buttons=[OKButton, CancelButton])

    traits_view = MEAPView(widgets, width=500, height=300, resizable=True)