class MEAPhysioExperiment(PhysioExperiment):
    events = List(Instance(ExpEvent))
    physio_files = List(Instance(PhysioFileMEAAnalysis))

    def _b_save_fired(self):
        raise NotImplementedError()

    def file_constructor(self):
        return PhysioFileMEAAnalysis

    def event_constructor(self):
        return MovingEnsembler

    def run(self):
        # Read the excel file
        self.open_xls()
        # Organize all the data from it
        if not self.events_processed:
            self._process_events()
        # Run the pipeline on each acq_file
        self.extract_mea()

    def extract_mea(self):
        stacks = {}
        for evt in self.conditions:
            stacks[evt] = {}
        for pf in self.physio_files:
            pf.extract_mea()

    def save_mea_document(self, output_path=""):
        #self.ensemble_average()
        if not output_path:
            output_path = self.output_xls

    traits_view = MEAPView(
        HSplit(
            Group(
                Group(Item("input_path", label="Input XLS"),
                      Item("output_directory", label="Output Directory"),
                      Item("b_run", show_label=False),
                      Item("b_save", show_label=False),
                      orientation="horizontal"),
            ),
        ),
        resizable=True, win_title="MEA Analysis")
class MEAPGreeter(HasTraits):
    preproc = Button("Preprocess")
    analyze = Button("Analyze")
    configure = Button("Configure MEAP")
    batch_spreadsheet = Button("Create Batch Spreadsheet")
    register_dZdt = Button("Batch Register dZ/dt")

    def _preproc_fired(self):
        from meap.preprocessing import PreprocessingPipeline
        preproc = PreprocessingPipeline()
        preproc.configure_traits()

    def _analyze_fired(self):
        from meap.physio_analysis import PhysioExperiment
        pe = PhysioExperiment()
        pe.configure_traits()

    def _register_dZdt_fired(self):
        from meap.batch_warp_dzdt import BatchGroupRegisterDZDT
        batch = BatchGroupRegisterDZDT()
        batch.configure_traits()

    def _batch_spreadsheet_fired(self):
        from meap.make_batch_spreadsheet import BatchFileTool
        bft = BatchFileTool()
        bft.configure_traits()

    def _configure_fired(self):
        print "Not implemented yet!"

    traits_view = MEAPView(
        VGroup(
            Item("preproc"),
            Item("analyze"),
            Item("batch_spreadsheet"),
            Item("register_dZdt"),
            Item("configure"),
            show_labels=False))
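# Usage sketch (an illustration, not in the original source): MEAPGreeter is
# the top-level menu, so a launcher script only needs to instantiate it and
# open its traits UI. `main` is a hypothetical entry-point name.
def main():
    MEAPGreeter().configure_traits()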
class PreprocessingPipeline(HasTraits):
    # Holds the set of multiple acq files
    group_excel_file = File
    # Which columns contain numeric data?
    # Holds the unique event types for this experiment
    physio_files = List(Instance(PhysioFilePreprocessor))
    # Interactively process the data?
    interactive = Bool(True)
    header = List

    @on_trait_change("group_excel_file")
    def open_xls(self, fpath=""):
        if not fpath:
            fpath = self.group_excel_file
        logger.info("Loading " + fpath)
        wb = xlrd.open_workbook(fpath)
        sheet = wb.sheet_by_index(0)
        header = [str(item.value) for item in sheet.row(0)]
        for rnum in range(1, sheet.nrows):
            rw = sheet.row(rnum)
            if not len(rw) == len(header):
                #raise ValueError("not enough information in row %d" % (rnum + 1))
                continue
            # Convert to numbers and strings
            specs = {}
            for item, hdr in zip(rw, header):
                if hdr == "in_mri":
                    specs[hdr] = bool(item.value)
                else:
                    try:
                        specs[hdr] = float(item.value)
                    except Exception:
                        specs[hdr] = str(item.value)
            self.physio_files.append(
                PhysioFilePreprocessor(interactive=self.interactive,
                                       file=specs['file'],
                                       outfile=specs.get('outfile', ''),
                                       specs=specs))
        self.header = header

    traits_view = MEAPView(
        VGroup(
            VGroup("group_excel_file", show_border=True,
                   label="Excel Files"),
            HGroup(Item("physio_files", editor=pipeline_table,
                        show_label=False)),
        ),
        resizable=True, width=300, height=500)
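# Hypothetical sketch (not original code): the pipeline can also be driven
# without the GUI, since open_xls() accepts an explicit path. The spreadsheet
# is assumed to have the columns open_xls() reads ('file' and, optionally,
# 'outfile' plus any subject_* measurement columns).
def load_batch(xls_path):
    pipeline = PreprocessingPipeline(interactive=False)
    pipeline.open_xls(xls_path)
    return pipeline.physio_files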
class CensorRegion(HasTraits):
    start_time = Float(-1.)
    end_time = Float(-1.)
    viz = Instance(RangeSelectionOverlay, transient=True)
    metadata_name = Str
    plot = Instance(LinePlot, transient=True)

    traits_view = MEAPView(Item("start_time"), Item("end_time"))

    def set_limits(self, start, end):
        """Write this region's (start, end) range into the plot's metadata."""
        self.plot.index.metadata[self.metadata_name] = start, end

    @on_trait_change("plot.index.metadata")
    def metadata_changed(self):
        try:
            if self.plot.index.metadata[self.metadata_name] is None:
                return
            st, end = self.plot.index.metadata[self.metadata_name]
        except KeyError:
            return
        self.start_time = st
        self.end_time = end

    def _viz_default(self):
        self.plot.active_tool = CensorSelection(
            self.plot,
            selection_mode="append",
            append_key=KeySpec("control"),
            left_button_selects=True,
            metadata_name=self.metadata_name)
        rso = RangeSelectionOverlay(component=self.plot,
                                    metadata_name=self.metadata_name)
        self.plot.overlays.append(rso)
        return rso
class PhysioFilePreprocessor(HasTraits):
    # Holds the physiological data loaded from the file
    physiodata = Instance(PhysioData)
    interactive = Bool(True)
    pipeline = Instance(MEAPPipeline)
    file = DelegatesTo("pipeline")
    outfile = DelegatesTo("pipeline")
    specs = Dict
    importer_kwargs = DelegatesTo("pipeline")
    mea_saved = Property(Bool, depends_on="file,outfile")

    def _outfile_default(self):
        if self.file.endswith(".acq"):
            return self.file[:-4] + ".mea"
        return self.file + ".mea"

    def __init__(self, **traits):
        super(PhysioFilePreprocessor, self).__init__(**traits)
        # Some of the data contained in the excel sheet should get
        # attached to the physiodata file once it is loaded. These
        # go into the importer_kwargs attribute.
        pd_traits = PhysioData().editable_traits()
        for spec, spec_value in self.specs.iteritems():
            if "subject_" + spec in pd_traits:
                self.importer_kwargs["subject_" + spec] = spec_value

    def _pipeline_default(self):
        pipeline = MEAPPipeline()
        return pipeline

    def _get_mea_saved(self):
        if self.file.endswith("mea"):
            return True
        return os.path.exists(self.pipeline.outfile) or \
            os.path.exists(self.pipeline.outfile + ".mat")

    @on_trait_change("pipeline.saved")
    def pipeline_saved(self):
        """A dumb hack: reassigning outfile forces the mea_saved
        property to recompute after the pipeline saves."""
        old_outfile = self.outfile
        self.outfile = "asdf"
        self.outfile = old_outfile

    pipeline_view = MEAPView(
        Group(Item("pipeline", style="custom"), show_labels=False))
    def default_traits_view(self):
        plots = [Item(sig + "_ts", style="custom", show_label=False)
                 for sig in sorted(self.physiodata.contents)
                 if not sig.startswith("resp_corr")]
        if "ecg2" in self.physiodata.contents:
            widgets = VGroup(
                Item("qrs_source_signal", label="ECG to use"),
                Item("start_time"),
                Item("window_size"),
            )
        else:
            widgets = VGroup(
                Item("start_time"),
                Item("window_size"),
            )
        return MEAPView(VGroup(VGroup(*plots),
                               Group(widgets, orientation="horizontal")),
                        width=1000, height=600, resizable=True,
                        win_title="AcqKnowledge Data",
                        buttons=[OKButton, CancelButton],
                        handler=DataPlotHandler())
class RespirationProcessor(HasTraits):
    # Parameters
    resp_polort = DelegatesTo("physiodata")
    resp_high_freq_cutoff = DelegatesTo("physiodata")
    time = DelegatesTo("physiodata", "processed_respiration_time")

    # Data object and arrays to save
    physiodata = Instance(PhysioData)
    respiration_cycle = DelegatesTo("physiodata")
    respiration_amount = DelegatesTo("physiodata")

    # For visualization
    plots = Instance(VPlotContainer, transient=True)

    # Respiration plot
    resp_plot = Instance(Plot, transient=True)
    resp_plot_data = Instance(ArrayPlotData, transient=True)
    raw_resp_signal = Array()  # Could be filtered z0 or resp belt
    resp_signal = Property(
        Array,
        depends_on="physiodata.resp_polort,physiodata.resp_high_freq_cutoff")
    polyfilt_resp = Array()
    lpfilt_resp = DelegatesTo("physiodata", "processed_respiration_data")
    resp_inhale_begin_times = DelegatesTo("physiodata")
    resp_inhale_begin_values = Array
    resp_exhale_begin_times = DelegatesTo("physiodata")
    resp_exhale_begin_values = Array

    z0_plot = Instance(Plot, transient=True)
    z0_plot_data = Instance(ArrayPlotData, transient=True)
    raw_z0_signal = Array()
    resp_corrected_z0 = DelegatesTo("physiodata")

    dzdt_plot = Instance(Plot, transient=True)
    dzdt_plot_data = Instance(ArrayPlotData, transient=True)
    raw_dzdt_signal = Array()
    resp_corrected_dzdt = DelegatesTo("physiodata")

    start_time = Range(low=0, high=10000.0, initial=0.0)
    window_size = Range(low=1.0, high=300.0, initial=30.0)
    state = Enum("unusable", "z0", "resp", "z0_resp")
    dirty = Bool(True)

    b_process = Button(label="Process Respiration")

    def __init__(self, **traits):
        super(RespirationProcessor, self).__init__(**traits)
        resp_inc = "respiration" in self.physiodata.contents
        z0_inc = "z0" in self.physiodata.contents
        if z0_inc and resp_inc:
            self.state = "z0_resp"
            messagebox("Using both respiration belt and ICG")
            z0_signal = self.physiodata.z0_data
            dzdt_signal = self.physiodata.dzdt_data
            resp_signal = self.physiodata.respiration_data
        elif resp_inc and not z0_inc:
            self.state = "resp"
            messagebox("Only respiration belt data will be used.")
            resp_signal = self.physiodata.respiration_data
            z0_signal = None
            dzdt_signal = None
        elif z0_inc and not resp_inc:
            self.state = "z0"
            messagebox("Using only z0 channel to estimate respiration")
            resp_signal = self.physiodata.z0_data
            z0_signal = self.physiodata.z0_data
            dzdt_signal = self.physiodata.dzdt_data
            self.resp_polort = 1
        else:
            self.state = "unusable"
            messagebox("No respiration belt or z0 channels found")
            # No usable signals; leave the processor empty
            return

        # Establish the maximum shared length of the signals
        signals = [sig for sig in (resp_signal, z0_signal, dzdt_signal)
                   if sig is not None]
        if len(signals) > 1:
            minlen = min([len(sig) for sig in signals])
        else:
            minlen = len(signals[0])

        # Establish a time array
        if resp_inc:
            self.time = TimeSeries(physiodata=self.physiodata,
                                   contains="respiration").time[:minlen]
        else:
            self.time = TimeSeries(physiodata=self.physiodata,
                                   contains="z0").time[:minlen]

        # Get out the final signals
        if z0_inc:
            self.raw_resp_signal = z0_signal[:minlen].copy()
            self.raw_resp_signal[:50] = self.raw_resp_signal.mean()
            self.raw_z0_signal = z0_signal[:minlen].copy()
            self.raw_z0_signal[:50] = self.raw_z0_signal.mean()
            self.raw_dzdt_signal = dzdt_signal[:minlen].copy()
            self.raw_dzdt_signal[:50] = self.raw_dzdt_signal.mean()
        if resp_inc:
            self.raw_resp_signal = resp_signal[:minlen].copy()
        # If data already exists, it can't be dirty
        if self.physiodata.resp_exhale_begin_times.size > 0:
            self.dirty = False
        self.on_trait_change(self._parameter_changed,
"resp_polort,resp_high_freq_cutoff") @cached_property def _get_resp_signal(self): resp_signal = winsorize(self.raw_resp_signal) return (resp_signal - resp_signal.mean()) / resp_signal.std() def _parameter_changed(self): self.dirty = True def _b_process_fired(self): self.process() def process(self): """ processes the respiration timeseries """ sampling_rate = 1. / (self.time[1] - self.time[0]) # Respiration belts can lose tension over time. # This removes linear trends if self.resp_polort > 0: pfit = legendre_detrend(self.resp_signal, self.resp_polort) if not pfit.shape == self.time.shape: messagebox("Legendre detrend failed") return self.polyfilt_resp = pfit else: self.polyfilt_resp = self.resp_signal lpd = lowpass(self.polyfilt_resp, self.resp_high_freq_cutoff, sampling_rate) if not lpd.shape == self.time.shape: messagebox("lowpass filter failed") return self.lpfilt_resp = (lpd - lpd.mean()) / lpd.std() resp_inhale_begin_indices, resp_exhale_begin_indices = find_peaks( self.lpfilt_resp, maxima=True, minima=True) self.resp_inhale_begin_times = self.time[resp_inhale_begin_indices] self.resp_exhale_begin_times = self.time[resp_exhale_begin_indices] self.resp_corrected_z0 = regress_out(self.raw_z0_signal, self.lpfilt_resp) self.resp_corrected_dzdt = regress_out(self.raw_dzdt_signal, self.lpfilt_resp) # update the scatterplot if we're interactive if self.resp_plot_data is not None: self.resp_plot_data.set_data("inhale_times", self.resp_inhale_begin_times) self.resp_plot_data.set_data( "inhale_values", self.lpfilt_resp[resp_inhale_begin_indices]) self.resp_plot_data.set_data("exhale_times", self.resp_exhale_begin_times) self.resp_plot_data.set_data( "exhale_values", self.lpfilt_resp[resp_exhale_begin_indices]) self.resp_plot_data.set_data("lpfilt_resp", self.lpfilt_resp) self.dzdt_plot_data.set_data("cleaned_data", self.resp_corrected_dzdt) self.z0_plot_data.set_data("cleaned_data", self.resp_corrected_z0) else: print "plot data is none" # Update the respiration cycles cached_pro times = np.concatenate([ np.zeros(1), self.resp_exhale_begin_times, self.resp_inhale_begin_times, self.time[-1, np.newaxis] ]) vals = np.concatenate([ np.zeros(1), np.zeros_like(resp_exhale_begin_indices), 0.5 * np.ones_like(resp_inhale_begin_indices), np.zeros(1) ]) srt = np.argsort(times) terp = interp1d(times[srt], vals[srt]) self.respiration_cycle = terp(self.time) self.respiration_amount = np.abs( np.ediff1d(self.respiration_cycle, to_begin=0)) def _resp_plot_default(self): # Create plotting components self.resp_plot_data = ArrayPlotData( time=self.time, inhale_times=self.resp_inhale_begin_times, inhale_values=self.resp_inhale_begin_values, exhale_times=self.resp_exhale_begin_times, exhale_values=self.resp_exhale_begin_values, filtered_resp=self.resp_signal, lpfilt_resp=self.lpfilt_resp) plot = Plot(self.resp_plot_data) plot.plot(("time", "filtered_resp"), color="blue", line_width=1) plot.plot(("time", "lpfilt_resp"), color="green", line_width=1) # Plot the inhalation peaks plot.plot(("inhale_times", "inhale_values"), type="scatter", marker="square") plot.plot(("exhale_times", "exhale_values"), type="scatter", marker="circle") plot.title = "Respiration" plot.title_position = "right" plot.title_angle = 270 plot.padding = 20 return plot def _z0_plot_default(self): self.z0_plot_data = ArrayPlotData(time=self.time, raw_data=self.raw_z0_signal, cleaned_data=self.resp_corrected_z0) plot = Plot(self.z0_plot_data) plot.plot(("time", "raw_data"), color="blue", line_width=1) plot.plot(("time", "cleaned_data"), 
color="green", line_width=1) plot.title = "z0" plot.title_position = "right" plot.title_angle = 270 plot.padding = 20 return plot def _dzdt_plot_default(self): """ Creates a plot of the ecg_ts data and the signals derived during the Pan Tomkins algorithm """ # Create plotting components self.dzdt_plot_data = ArrayPlotData( time=self.time, raw_data=self.raw_dzdt_signal, cleaned_data=self.resp_corrected_dzdt) plot = Plot(self.dzdt_plot_data) plot.plot(("time", "raw_data"), color="blue", line_width=1) plot.plot(("time", "cleaned_data"), color="green", line_width=1) plot.title = "dZ/dt" plot.title_position = "right" plot.title_angle = 270 plot.padding = 20 return plot def _plots_default(self): plots_to_include = () if self.state in ("z0_resp", "z0"): self.index_range = self.resp_plot.index_range self.dzdt_plot.index_range = self.index_range self.z0_plot.index_range = self.index_range plots_to_include = [self.resp_plot, self.z0_plot, self.dzdt_plot] elif self.state == "resp": self.index_range = self.resp_plot.index_range plots_to_include = [self.resp_plot] return VPlotContainer(*plots_to_include) @on_trait_change("window_size,start_time") def update_plot_range(self): self.resp_plot.index_range.high = self.start_time + self.window_size self.resp_plot.index_range.low = self.start_time proc_params_group = Group(Group( Item("resp_polort"), Item("resp_high_freq_cutoff"), Item("b_process", show_label=False, enabled_when="state != unusable and dirty"), label="Resp processing options", show_border=True, orientation="vertical", springy=True), orientation="horizontal") plot_group = Group( Group(Item("plots", editor=ComponentEditor(), width=800, height=700), show_labels=False), Item("start_time"), Item("window_size")) traits_view = MEAPView( VSplit( plot_group, proc_params_group, ), resizable=True, win_title="Process Respiration Data", )
class GroupRegisterDZDT(HasTraits):
    physiodata = Instance(PhysioData)

    global_ensemble = Instance(GlobalEnsembleAveragedHeartBeat)
    srvf_lambda = DelegatesTo("physiodata")
    srvf_max_karcher_iterations = DelegatesTo("physiodata")
    srvf_update_min = DelegatesTo("physiodata")
    srvf_karcher_mean_subset_size = DelegatesTo("physiodata")
    dzdt_srvf_karcher_mean = DelegatesTo("physiodata")
    dzdt_karcher_mean = DelegatesTo("physiodata")
    dzdt_warping_functions = DelegatesTo("physiodata")
    srvf_use_moving_ensembled = DelegatesTo("physiodata")
    bspline_before_warping = DelegatesTo("physiodata")
    dzdt_functions_to_warp = DelegatesTo("physiodata")

    # Configures the slice of time to be registered
    srvf_t_min = DelegatesTo("physiodata")
    srvf_t_max = DelegatesTo("physiodata")
    dzdt_karcher_mean_time = DelegatesTo("physiodata")
    dzdt_mask = Array

    # Holds indices of beats used to calculate initial Karcher mean
    dzdt_karcher_mean_inputs = DelegatesTo("physiodata")
    dzdt_karcher_mean_over_iterations = DelegatesTo("physiodata")
    dzdt_num_inputs_to_group_warping = DelegatesTo("physiodata")
    srvf_iteration_distances = DelegatesTo("physiodata")
    srvf_iteration_energy = DelegatesTo("physiodata")

    # For calculating multiple modes
    all_beats_registered_to_initial = Bool(False)
    all_beats_registered_to_mode = Bool(False)
    n_modes = DelegatesTo("physiodata")
    max_kmeans_iterations = DelegatesTo("physiodata")
    mode_dzdt_karcher_means = DelegatesTo("physiodata")
    mode_cluster_assignment = DelegatesTo("physiodata")
    mode_dzdt_srvf_karcher_means = DelegatesTo("physiodata")

    # Graphics items
    karcher_plot = Instance(Plot, transient=True)
    registration_plot = Instance(HPlotContainer, transient=True)
    karcher_plotdata = Instance(ArrayPlotData, transient=True)

    # Buttons
    b_calculate_karcher_mean = Button(label="Calculate Karcher Mean")
    b_align_all_beats = Button(label="Warp all")
    b_find_modes = Button(label="Discover Modes")
    b_edit_modes = Button(label="Score Modes")
    interactive = Bool(False)

    # Holds the karcher modes
    mode_beat_train = Instance(ModeKarcherBeatTrain)
    edit_listening = Bool(False, desc="If true, update_plots is called"
                          " when a beat gets hand labeled")

    def __init__(self, **traits):
        super(GroupRegisterDZDT, self).__init__(**traits)
        # Set the initial path to whatever's in the physiodata file
        logger.info("Initializing dZdt registration")
        # Is there already a Karcher mean in the physiodata?
        self.karcher_mean_calculated = self.dzdt_srvf_karcher_mean.size > 0 \
            and self.dzdt_karcher_mean.size > 0

        # Process dZdt data before srvf analysis
        self._update_original_functions()

        self.n_functions = self.dzdt_functions_to_warp.shape[0]
        self.n_samples = self.dzdt_functions_to_warp.shape[1]

        self._init_warps()
        self.select_new_samples()

    def _mode_beat_train_default(self):
        logger.info("creating default mea_beat_train")
        assert self.physiodata is not None
        mkbt = ModeKarcherBeatTrain(physiodata=self.physiodata)
        return mkbt

    def _init_warps(self):
        """ For loading warps from mea.mat """
        if self.dzdt_warping_functions.shape[0] == \
                self.dzdt_functions_to_warp.shape[0]:
            self.all_beats_registered_to_mode = True
            self._forward_warp_beats()

    def _global_ensemble_default(self):
        return GlobalEnsembleAveragedHeartBeat(physiodata=self.physiodata)

    def _karcher_plot_default(self):
        """ Instead of defining these in __init__, only construct the plots
        when a ui is requested """
        self.interactive = True
        unreg_mean = self.global_ensemble.dzdt_signal
        # Temporarily fill in the karcher mean
        if not self.karcher_mean_calculated:
            karcher_mean = unreg_mean[self.dzdt_mask]
        else:
            karcher_mean = self.dzdt_karcher_mean

        self.karcher_plotdata = ArrayPlotData(
            time=self.global_ensemble.dzdt_time,
            unregistered_mean=self.global_ensemble.dzdt_signal,
            karcher_mean=karcher_mean,
            karcher_time=self.dzdt_karcher_mean_time)
        karcher_plot = Plot(self.karcher_plotdata)
        karcher_plot.plot(("time", "unregistered_mean"),
                          line_width=1, color="lightblue")
        line_plot = karcher_plot.plot(("karcher_time", "karcher_mean"),
                                      line_width=3, color="blue")[0]
        # Create an overlay tool
        karcher_plot.datasources['karcher_time'].sort_order = "ascending"
        return karcher_plot

    def _registration_plot_default(self):
        """ Instead of defining these in __init__, only construct the plots
        when a ui is requested """
        unreg_mean = self.global_ensemble.dzdt_signal
        # Temporarily fill in the karcher mean
        if not self.karcher_mean_calculated:
            karcher_mean = unreg_mean[self.dzdt_mask]
        else:
            karcher_mean = self.dzdt_karcher_mean

        self.single_registration_plotdata = ArrayPlotData(
            karcher_time=self.dzdt_karcher_mean_time,
            karcher_mean=karcher_mean,
            registered_func=karcher_mean)
        self.single_plot = Plot(self.single_registration_plotdata)
        self.single_plot.plot(("karcher_time", "karcher_mean"),
                              line_width=2, color="blue")
        self.single_plot.plot(("karcher_time", "registered_func"),
                              line_width=2, color="maroon")

        if self.all_beats_registered_to_initial or \
                self.all_beats_registered_to_mode:
            image_data = self.registered_functions.copy()
        else:
            image_data = self.dzdt_functions_to_warp.copy()

        # Create a plot of all the functions, registered or not
        self.all_registration_plotdata = ArrayPlotData(image_data=image_data)
        self.all_plot = Plot(self.all_registration_plotdata)
        self.all_plot.img_plot("image_data", colormap=jet)

        return HPlotContainer(self.single_plot, self.all_plot)

    def _forward_warp_beats(self):
        """ Create registered beats to plot, since they're not stored """
        logger.info("Applying warps to functions for plotting")
        # Re-create gam
        gam = self.dzdt_warping_functions - self.srvf_t_min
        gam = gam / (self.srvf_t_max - self.srvf_t_min)

        aligned_functions = np.zeros_like(self.dzdt_functions_to_warp)
        t = self.dzdt_karcher_mean_time
        for k in range(self.n_functions):
            aligned_functions[k] = np.interp(
                (t[-1] - t[0]) * gam[k] + t[0], t,
                self.dzdt_functions_to_warp[k])

        self.registered_functions = aligned_functions

    @on_trait_change("dzdt_num_inputs_to_group_warping")
    def select_new_samples(self):
        nbeats = self.dzdt_functions_to_warp.shape[0]
        nsamps = min(self.dzdt_num_inputs_to_group_warping, nbeats)
        self.dzdt_karcher_mean_inputs = np.random.choice(nbeats, size=nsamps,
                                                         replace=False)

    @on_trait_change("physiodata.srvf_lambda, "
                     "physiodata.srvf_update_min, "
                     "physiodata.srvf_use_moving_ensembled, "
                     "physiodata.srvf_karcher_mean_subset_size, "
                     "physiodata.bspline_before_warping")
    def params_edited(self):
        self.dirty = True
        self.karcher_mean_calculated = False
        self.all_beats_registered_to_initial = False
        self.all_beats_registered_to_mode = False
        self._update_original_functions()

    def _update_original_functions(self):
        logger.info("updating time slice and functions to register")
        # Get the time relative to R
        dzdt_time = np.arange(self.physiodata.dzdt_matrix.shape[1],
                              dtype=np.float) - self.physiodata.dzdt_pre_peak
        self.dzdt_mask = (dzdt_time >= self.srvf_t_min) * \
                         (dzdt_time <= self.srvf_t_max)
        srvf_time = dzdt_time[self.dzdt_mask]
        self.dzdt_karcher_mean_time = srvf_time

        # Extract the corresponding data
        self.dzdt_functions_to_warp = \
            self.physiodata.mea_dzdt_matrix[:, self.dzdt_mask] \
            if self.srvf_use_moving_ensembled \
            else self.physiodata.dzdt_matrix[:, self.dzdt_mask]

        if self.bspline_before_warping:
            logger.info("Smoothing inputs with B-Splines")
            self.dzdt_functions_to_warp = np.row_stack([
                UnivariateSpline(self.dzdt_karcher_mean_time, func,
                                 s=0.05)(self.dzdt_karcher_mean_time)
                for func in self.dzdt_functions_to_warp])

        if self.interactive:
            self.all_registration_plotdata.set_data(
                "image_data", self.dzdt_functions_to_warp.copy())
            self.all_plot.request_redraw()

    def _b_calculate_karcher_mean_fired(self):
        self.calculate_karcher_mean()

    def calculate_karcher_mean(self):
        """ Calculates an initial Karcher mean. """
        reg_prob = RegistrationProblem(
            self.dzdt_functions_to_warp[self.dzdt_karcher_mean_inputs].T,
            sample_times=self.dzdt_karcher_mean_time,
            max_karcher_iterations=self.srvf_max_karcher_iterations,
            lambda_value=self.srvf_lambda,
            update_min=self.srvf_update_min)
        reg_prob.run_registration_parallel()
        self.dzdt_karcher_mean = reg_prob.function_karcher_mean
        self.dzdt_srvf_karcher_mean = reg_prob.srvf_karcher_mean
        self.karcher_mean_calculated = True

        # Update plots if this is interactive
        if self.interactive:
            self.karcher_plotdata.set_data("karcher_mean",
                                           self.dzdt_karcher_mean)
            self.karcher_plot.request_redraw()
            self.single_registration_plotdata.set_data(
                "karcher_mean", self.dzdt_karcher_mean)
            self.single_plot.request_redraw()

        self.rp = reg_prob

    def _b_align_all_beats_fired(self):
        self.align_all_beats_to_initial()

    def _b_find_modes_fired(self):
        self.detect_modes()

    def detect_modes(self):
        """ Uses the SRD-based clustering method described in Kurtek 2017 """
        if not self.karcher_mean_calculated:
            fail("Must calculate an initial Karcher mean first")
            return
        if not self.all_beats_registered_to_initial:
            fail("All beats must be registered to the initial Karcher mean")
            return

        # Calculate the SRDs
        dzdt_functions_to_warp = self.dzdt_functions_to_warp.T
        warps = self.dzdt_warping_functions.T
        wmax = warps.max()
        wmin = warps.min()
        warps = (warps - wmin) / (wmax - wmin)
        densities = np.diff(warps, axis=0)
        SRDs = np.sqrt(densities)

        # Pairwise Fisher-Rao distances between the SRDs
        srd_pairwise = pairwise_distances(SRDs.T, metric=fisher_rao_dist)
        tri = srd_pairwise[np.triu_indices_from(srd_pairwise, 1)]
        linkage = complete(tri)

        # Performs an iteration of k-means
        def cluster_karcher_means(initial_assignments):
            cluster_means = {}
            cluster_ids = np.unique(initial_assignments).tolist()
            warping_functions = np.zeros_like(SRDs)
            # Calculate a Karcher mean for each cluster
            for cluster_id in cluster_ids:
                print "Cluster ID:", cluster_id
                cluster_id_mask = initial_assignments == cluster_id
                cluster_srds = SRDs[:, cluster_id_mask]

                # If there is only a single SRD in this cluster, it is the mean
                if cluster_id_mask.sum() == 1:
                    cluster_means[cluster_id] = cluster_srds
                    continue

                # Run group registration to get the Karcher mean
                cluster_reg = RegistrationProblem(
                    cluster_srds,
                    sample_times=np.arange(SRDs.shape[0], dtype=np.float),
                    max_karcher_iterations=self.srvf_max_karcher_iterations,
                    lambda_value=self.srvf_lambda,
                    update_min=self.srvf_update_min)
                cluster_reg.run_registration_parallel()
                cluster_means[cluster_id] = cluster_reg.function_karcher_mean
                warping_functions[:, cluster_id_mask] = \
                    cluster_reg.mean_to_orig_warps

            # Scale the cluster Karcher means so the FR distance works
            scaled_cluster_means = {}
            for k, v in cluster_means.iteritems():
                scaled_cluster_means[k] = rescale_cluster_mean(v)

            # There are now k cluster means; which is closest for each SRD?
            # Also save each SRD's distance to its cluster's Karcher mean
            srd_cluster_assignments, srd_cluster_distances = get_closest_mean(
                SRDs, scaled_cluster_means)
            return (srd_cluster_assignments, srd_cluster_distances,
                    scaled_cluster_means, warping_functions)

        # Iterate until assignments stabilize
        last_assignments = fcluster(linkage, self.n_modes,
                                    criterion="maxclust")
        stabilized = False
        n_iters = 0
        old_assignments = [last_assignments]
        old_means = []
        while not stabilized and n_iters < self.max_kmeans_iterations:
            logger.info("Iteration %d", n_iters)
            assignments, distances, cluster_means, warping_funcs = \
                cluster_karcher_means(last_assignments)
            stabilized = np.all(last_assignments == assignments)
            last_assignments = assignments.copy()
            old_assignments.append(last_assignments)
            old_means.append(cluster_means)
            n_iters += 1

        # Finalize the clusters by aligning all the functions to their
        # cluster's mean
        cluster_means = {}
        cluster_ids = np.unique(assignments)
        warping_functions = np.zeros_like(dzdt_functions_to_warp)
        self.registered_functions = np.zeros_like(self.dzdt_functions_to_warp)
        # Calculate a Karcher mean for each cluster
        for cluster_id in cluster_ids:
            cluster_id_mask = assignments == cluster_id
            cluster_funcs = self.dzdt_functions_to_warp.T[:, cluster_id_mask]

            # If there is only a single function in this cluster, it is the mean
            if cluster_id_mask.sum() == 1:
                cluster_means[cluster_id] = cluster_funcs
                continue

            # Run group registration to get the Karcher mean
            cluster_reg = RegistrationProblem(
                cluster_funcs,
                sample_times=self.dzdt_karcher_mean_time,
                max_karcher_iterations=self.srvf_max_karcher_iterations,
                lambda_value=self.srvf_lambda,
                update_min=self.srvf_update_min)
            cluster_reg.run_registration_parallel()
            cluster_means[cluster_id] = cluster_reg.function_karcher_mean
            warping_functions[:, cluster_id_mask] = \
                cluster_reg.mean_to_orig_warps
            self.registered_functions[cluster_id_mask] = \
                cluster_reg.registered_functions.T

        # Save the warps to the modes as the final warping functions
        self.dzdt_warping_functions = warping_functions.T \
            * (self.srvf_t_max - self.srvf_t_min) \
            + self.srvf_t_min

        # Re-order the means
        cluster_ids = sorted(cluster_means.keys())
        final_assignments = np.zeros_like(assignments)
        final_modes = []
        for final_id, orig_id in enumerate(cluster_ids):
            final_assignments[assignments == orig_id] = final_id
            final_modes.append(cluster_means[orig_id].squeeze())

        self.mode_dzdt_karcher_means = np.row_stack(final_modes)
        self.mode_cluster_assignment = final_assignments
        self.all_beats_registered_to_mode = True

    def align_all_beats_to_initial(self):
        if not self.karcher_mean_calculated:
            logger.warn("Calculate Karcher mean first")
            return
        logger.info("Aligning all beats to the Karcher Mean")
        if self.interactive:
            progress = ProgressDialog(
                title="ICG Warping", min=0,
                max=len(self.physiodata.peak_times), show_time=True,
                message="Warping dZ/dt to Karcher Mean...")
            progress.open()

        template_func = self.dzdt_srvf_karcher_mean
        normed_template_func = template_func / np.linalg.norm(template_func)
        fy, fx = np.gradient(self.dzdt_functions_to_warp.T, 1, 1)
        srvf_functions = (fy / np.sqrt(np.abs(fy) + eps)).T

        gam = np.zeros(self.dzdt_functions_to_warp.shape, dtype=np.float)
        aligned_functions = self.dzdt_functions_to_warp.copy()
        logger.info("aligned_functions %d", id(aligned_functions))
        logger.info("self.dzdt_functions_to_warp %d",
                    id(self.dzdt_functions_to_warp))

        t = self.dzdt_karcher_mean_time
        for k in range(self.n_functions):
            q_c = srvf_functions[k] / np.linalg.norm(srvf_functions[k])
            G, T = dp(normed_template_func, t, q_c, t, t, t,
                      self.srvf_lambda)
            gam0 = np.interp(self.dzdt_karcher_mean_time, T, G)
            gam[k] = (gam0 - gam0[0]) / (gam0[-1] - gam0[0])  # change scale
            aligned_functions[k] = np.interp(
                (t[-1] - t[0]) * gam[k] + t[0], t,
                self.dzdt_functions_to_warp[k])

            if self.interactive:
                # Update the image plot
                self.all_registration_plotdata.set_data("image_data",
                                                        aligned_functions)
                # Update the registration plot
                self.single_registration_plotdata.set_data(
                    "registered_func", aligned_functions[k])
                self.single_plot.request_redraw()
                (cont, skip) = progress.update(k)

        self.registered_functions = aligned_functions.copy()

        self.dzdt_warping_functions = gam * \
            (self.srvf_t_max - self.srvf_t_min) + self.srvf_t_min

        if self.interactive:
            progress.update(k + 1)

        self.all_beats_registered_to_initial = True

    def _b_edit_modes_fired(self):
        self.mode_beat_train.edit_traits()

    def _point_plots_default(self):
        """ Instead of defining these in __init__, only construct the plots
        when a ui is requested """
        self.point_plotdata = ArrayPlotData(
            peak_times=self.physiodata.peak_times.flatten(),
            x_times=self.physiodata.x_indices - self.physiodata.dzdt_pre_peak,
            lvet=self.physiodata.lvet,
            pep=self.physiodata.pep)

        container = VPlotContainer(resizable="hv", bgcolor="lightgray",
                                   fill_padding=True, padding=10)
        for sig in ("pep", "x_times", "lvet"):
            temp_plt = Plot(self.point_plotdata)
            temp_plt.plot(("peak_times", sig), line_width=2)
            temp_plt.title = sig
            container.add(temp_plt)

        container.padding_top = 10
        return container

    mean_widgets = VGroup(
        VGroup(Item("karcher_plot", editor=ComponentEditor(),
                    show_label=False),
               Item("registration_plot", editor=ComponentEditor(),
                    show_label=False),
               show_labels=False),
        Item("srvf_use_moving_ensembled", label="Use Moving Ensembled dZ/dt"),
        Item("bspline_before_warping", label="B Spline smoothing"),
        Item("srvf_t_min", label="Epoch Start Time"),
        Item("srvf_t_max", label="Epoch End Time"),
        Item("srvf_lambda", label="Lambda Value"),
        Item("dzdt_num_inputs_to_group_warping",
             label="Template uses N beats"),
        Item("srvf_max_karcher_iterations", label="Max Karcher Iterations"),
        HGroup(
            Item("b_calculate_karcher_mean", label="Step 1:",
                 enabled_when="dirty"),
            Item("b_align_all_beats", label="Step 2:",
                 enabled_when="karcher_mean_calculated")),
        label="Initial Karcher Mean")

    mode_widgets = VGroup(
        Item("n_modes", label="Number of Modes/Clusters"),
        Item("max_kmeans_iterations"),
        Item("b_find_modes", show_label=False,
             enabled_when="all_beats_registered_to_initial"),
        Item("b_edit_modes",
             show_label=False,
             enabled_when="all_beats_registered_to_mode"))

    traits_view = MEAPView(HSplit(mean_widgets, mode_widgets),
                           resizable=True,
                           win_title="ICG Warping Tools",
                           width=800, height=700,
                           buttons=[OKButton, CancelButton])
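# Scripted usage sketch (an assumption, not original code): the buttons in
# the view above map onto public methods, so registration can also run
# headless for a single recording.
def register_dzdt_headless(physiodata):
    reg = GroupRegisterDZDT(physiodata=physiodata)
    reg.calculate_karcher_mean()      # "Calculate Karcher Mean" button
    reg.align_all_beats_to_initial()  # "Warp all" button
    reg.detect_modes()                # "Discover Modes" button
    return reg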
class MEAPConfig(HasTraits):
    # Parameters for point marking
    apply_ecg_smoothing = CBool(True)
    ecg_smoothing_window_len = Int(5)  # NOTE: Changed in 1.1 from 20
    apply_imp_smoothing = CBool(True)
    imp_smoothing_window_len = Int(40)
    apply_bp_smoothing = CBool(True)
    bp_smoothing_window_len = Int(80)

    # Parameters for waveform extraction
    peak_window = CInt(80)            # Range(low=5, high=400, value=200)
    ecg_pre_peak = CInt(300)          # Range(low=50, high=500, value=300)
    ecg_post_peak = CInt(400)         # Range(low=100, high=700, value=400)
    dzdt_pre_peak = CInt(300)         # Range(50, 500, value=300)
    dzdt_post_peak = CInt(700)        # Range(100, 1000, value=700)
    doppler_pre_peak = CInt(300)      # Range(50, 500, value=300)
    doppler_post_peak = CInt(700)     # Range(100, 1000, value=700)
    bp_pre_peak = CInt(300)           # Range(50, 500, value=300)
    bp_post_peak = CInt(1000)         # Range(100, 2500, value=1200)
    systolic_pre_peak = CInt(300)     # Range(50, 500, value=300)
    systolic_post_peak = CInt(1000)   # Range(100, 2500, value=1200)
    diastolic_pre_peak = CInt(300)    # Range(50, 500, value=300)
    diastolic_post_peak = CInt(1000)  # Range(100, 2500, value=1200)
    stroke_volume_equation = Enum("Kubicek", "Sramek-Bernstein")

    extraction_group = Group(
        Group(
            Item("peak_window"),
            Item("ecg_pre_peak"),
            Item("ecg_post_peak"),
            Item("dzdt_pre_peak"),
            Item("dzdt_post_peak"),
            Item("bp_pre_peak"),
            Item("bp_post_peak"),
            Item("stroke_volume_equation"),
            orientation="vertical",
            show_border=True,
            label="Waveform Extraction Parameters",
            springy=True),
        Group(
            Item("mhd_bandpass_min"),
            Item("mhd_bandpass_max"),
            Item("mhd_smoothing_window_len"),
            Item("mhd_smoothing_window"),
            Item("qrs_to_mhd_ratio"),
            Item("combined_smoothing_window_len"),
            Item("combined_smoothing_window"),
            enabled_when="subject_in_mri"))

    # Respiration analysis parameters
    process_respiration = CBool(True)
    resp_polort = CInt(7)
    resp_high_freq_cutoff = CFloat(0.35)
    regress_out_resp = Bool(False)

    # Parameters for processing the raw data before PT detection
    # MRI-specific
    subject_in_mri = CBool(False)

    peak_detection_algorithm = Enum("Pan Tomkins 83", "Multisignal", "ECG2")

    # Pan-Tomkins algorithm
    bandpass_min = Range(low=1, high=200, initial=5, value=5)
    bandpass_max = Range(low=1, high=200, initial=15, value=15)
    smoothing_window_len = Range(low=10, high=1000, initial=100, value=100)
    smoothing_window = Enum(SMOOTHING_WINDOWS)
    pt_adjust = Range(low=-2., high=2., value=0.00)
    peak_threshold = CFloat
    apply_filter = CBool(True)
    apply_diff_sq = CBool(True)
    apply_smooth_ma = CBool(True)
    pt_params_group = Group(
        Item("apply_filter"),
        Item("bandpass_min"),          # editor=RangeEditor(enter_set=True)
        Item("bandpass_max"),          # editor=RangeEditor(enter_set=True)
        Item("smoothing_window_len"),  # editor=RangeEditor(enter_set=True)
        Item("smoothing_window"),
        Item("pt_adjust"),             # editor=RangeEditor(enter_set=True)
        Item("apply_diff_sq"),
        Item("subject_in_mri"),
        label="R peak detection options",
        show_border=True,
        orientation="vertical",
        springy=True)

    # Second-signal heartbeat detection
    use_secondary_heartbeat = CBool(False)
    secondary_heartbeat = Enum("dzdt", "pulse_ox", "bp")
    secondary_heartbeat_pre_msec = CInt(400)
    secondary_heartbeat_abs = CBool(True)
    secondary_heartbeat_window = Enum(SMOOTHING_WINDOWS)
    secondary_heartbeat_window_len = CInt(801)
    secondary_heartbeat_n_likelihood_bins = CInt(15)

    # ECG2 parameters
    use_ECG2 = CBool(False)
    qrs_signal_source = Enum("ecg", "ecg2")
    ecg2_weight = Range(low=0., high=1., value=0.5)

    # Moving Ensembling Parameters
    mea_window_type = Enum("Seconds", "Beats")
    mea_n_neighbors = Range(low=0, high=60, value=8)
    mea_window_secs = Range(low=1., high=60., value=15.)
    mea_exp_power = Enum(2, 3, 4, 5, 6)
    mea_func_name = Enum("linear", "exponential", "flat")
    mea_weight_direction = Enum("symmetric", "before", "after")
    mea_smooth_hr = CBool(True)
    use_trimmed_co = CBool(True)

    # B-point classifier parameters
    bpoint_classifier_pre_point_msec = CInt(20)
    bpoint_classifier_post_point_msec = CInt(20)
    bpoint_classifier_sample_every_n_msec = CInt(1)
    bpoint_classifier_false_distance_min = CInt(5)
    bpoint_classifier_use_bpoint_prior = CBool(True)
    bpoint_classifier_include_derivative = CBool(True)

    # Doppler D point config
    db_point_type = Enum("min", "max")
    db_point_window_len = CInt(20)
    dx_point_type = Enum("min", "max")
    dx_point_window_len = CInt(20)

    # SRVF-warping parameters
    srvf_lambda = CFloat(0.0)
    srvf_max_karcher_iterations = CInt(15)
    srvf_update_min = CFloat(0.01)
    srvf_karcher_mean_subset_size = CInt(30)
    srvf_multi_mode_variance_cutoff = CFloat(0.3)
    srvf_use_moving_ensembled = CBool(False)
    dzdt_num_inputs_to_group_warping = CInt(25)
    srvf_t_min = CInt(0)    # Time in msec relative to R
    srvf_t_max = CInt(150)  # Also time in msec relative to R
    bspline_before_warping = CBool(True)
    n_modes = CInt(5)
    max_kmeans_iterations = CInt(5)
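# Illustrative note (not original code): the CBool/CInt/CFloat traits above
# coerce their inputs, so a config can be populated from spreadsheet cells
# that arrive as strings or floats. `row` is a hypothetical dict of values.
def config_from_row(row):
    return MEAPConfig(subject_in_mri=row.get("in_mri", False),
                      resp_polort=row.get("resp_polort", 7),
                      resp_high_freq_cutoff=row.get("resp_high_freq_cutoff",
                                                    0.35))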
class Importer(HasTraits):
    path = File
    channels = List(Instance(Channel))
    mapping_txt = File()
    mapper = Dict()

    # --- Subject information
    subject_age = CFloat(20.)
    subject_gender = Enum("M", "F")
    subject_weight = CFloat(135., label="Weight (lbs)")
    subject_height_ft = CInt(5, label="Height (ft)",
                             desc="Subject's height in feet")
    subject_height_in = CInt(10, label="Height (in)",
                             desc="Subject's height in inches")
    subject_electrode_distance_front = CFloat(
        32, label="Impedance electrode distance (front)")
    subject_electrode_distance_back = CFloat(
        32, label="Impedance electrode distance (back)")
    subject_electrode_distance_right = CFloat(
        32, label="Impedance electrode distance (right)")
    subject_electrode_distance_left = CFloat(
        32, label="Impedance electrode distance (left)")
    subject_in_mri = CBool(False)
    subject_control_base_impedance = CFloat(
        0., label="Control Impedance",
        desc="If in MRI, store the z0 value from outside the MRI")
    subject_resp_max = CFloat(100.,
                              label="Respiration circumference max (cm)")
    subject_resp_min = CFloat(0.,
                              label="Respiration circumference min (cm)")

    def _channel_map_default(self):
        return ChannelMapper()

    def _mapping_txt_changed(self):
        if not os.path.exists(self.mapping_txt):
            return
        with open(self.mapping_txt) as f:
            lines = [l.strip() for l in f]
        self.mapper = {}
        for line in lines:
            parts = line.split("->")
            if not len(parts) == 2:
                logger.info("Unable to parse line: %s", line)
                continue
            map_from, map_to = [l.strip() for l in parts]
            if map_to in SUPPORTED_SIGNALS:
                self.mapper[map_from] = map_to
                logger.info("mapping %s to %s", map_from, map_to)
            else:
                logger.info("Unrecognized channel: %s", map_to)

    traits_view = MEAPView(
        Group(
            Item("channels", editor=channel_table),
            label="Specify channel contents",
            show_border=True,
            show_labels=False),
        Group(
            Item("subject_age"),
            Item("subject_gender"),
            Item("subject_height_ft"),
            Item("subject_height_in"),
            Item("subject_weight"),
            Item("subject_electrode_distance_front"),
            Item("subject_electrode_distance_back"),
            Item("subject_electrode_distance_left"),
            Item("subject_electrode_distance_right"),
            Item("subject_resp_max"),
            Item("subject_resp_min"),
            Item("subject_in_mri"),
            Item('subject_control_base_impedance'),
            label="Participant Measurements"),
        buttons=[OKButton, CancelButton],
        resizable=True,
        win_title="Import Data",
    )

    def guess_channel_contents(self):
        mapped = set([v for k, v in self.mapper.iteritems()])
        for chan in self.channels:
            if chan.name in self.mapper:
                chan.contains = self.mapper[chan.name]
                continue
            cname = chan.name.lower()
            if "magnitude" in cname and "z0" not in mapped:
                chan.contains = "z0"
            elif "ecg" in cname and "ecg" not in mapped:
                chan.contains = "ecg"
            elif ("dz/dt" in cname or "derivative" in cname) and \
                    "dzdt" not in mapped:
                chan.contains = "dzdt"
            elif "diastolic" in cname and "diastolic" not in mapped:
                chan.contains = "diastolic"
            elif "systolic" in cname and "systolic" not in mapped:
                chan.contains = "systolic"
            elif ("blood pressure" in cname or "bp" in cname) and \
                    "bp" not in mapped:
                chan.contains = "bp"
            elif "resp" in cname and "respiration" not in mapped:
                chan.contains = "respiration"
            elif "trigger" in cname and "mri_trigger" not in mapped:
                chan.contains = "mri_trigger"
            elif "stimulus" in cname and "event" not in mapped:
                chan.contains = "event"

    def subject_data(self):
        return dict(
            subject_age=self.subject_age,
            subject_gender=self.subject_gender,
            subject_weight=self.subject_weight,
            subject_height_ft=self.subject_height_ft,
            subject_height_in=self.subject_height_in,
            subject_electrode_distance_front=self.subject_electrode_distance_front,
            subject_electrode_distance_back=self.subject_electrode_distance_back,
            subject_electrode_distance_left=self.subject_electrode_distance_left,
            subject_electrode_distance_right=self.subject_electrode_distance_right,
            subject_resp_max=self.subject_resp_max,
            subject_resp_min=self.subject_resp_min,
            subject_in_mri=self.subject_in_mri)

    def get_physiodata(self):
        raise NotImplementedError("Must be overridden by a subclass")
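# Example mapping_txt contents (the channel names here are made up):
# _mapping_txt_changed() expects one "source -> target" pair per line,
# and the target must appear in SUPPORTED_SIGNALS.
#
#     ECG100C -> ecg
#     dZ/dt   -> dzdt
#     Z0      -> z0
#     RSP100C -> respiration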
class BatchFileTool(HasTraits):
    # For saving outputs
    file_suffix = Str("_finished.mea.mat")
    input_file_extension = Enum(".mea.mat", ".acq", ".mat")
    overwrite = Bool(False)
    input_directory = Directory()
    output_directory = Directory()
    files = List(Instance(FileToProcess))
    spreadsheet_file = File(exists=False)

    b_save = Button("Save Spreadsheet")

    def _input_file_extension_changed(self):
        self._input_directory_changed()

    def _input_directory_changed(self):
        potential_files = glob(self.input_directory + "/*" +
                               self.input_file_extension)
        potential_files = [f for f in potential_files
                           if not f.endswith(self.file_suffix)]

        # Check if the output already exists
        def make_output_file(input_file):
            return input_file[:-len(self.input_file_extension)] + \
                self.file_suffix

        self.files = [FileToProcess(input_file=f,
                                    output_file=make_output_file(f))
                      for f in potential_files]

        # If no output directory is set, use the input directory
        if self.output_directory == '':
            self.output_directory = self.input_directory

    def _b_save_fired(self):
        def get_row():
            return {
                "file": "",
                "outfile": "",
                "weight": "",
                "height_ft": "",
                "height_in": "",
                "electrode_distance_front": "",
                "electrode_distance_back": "",
                "electrode_distance_left": "",
                "electrode_distance_right": "",
                "resp_max": "",
                "resp_min": "",
                "in_mri": "",
                "control_base_impedance": ""
            }

        rows = []
        for f in self.files:
            row = get_row()
            row['file'] = f.input_file
            row['outfile'] = f.output_file
            rows.append(row)

        df = pd.DataFrame(rows)
        logger.info("Writing spreadsheet to %s", self.spreadsheet_file)
        df.to_excel(self.spreadsheet_file, index=False)

    mean_widgets = VGroup(Item("input_file_extension"),
                          Item("input_directory"),
                          Item("output_directory"),
                          Item("file_suffix"),
                          Item("spreadsheet_file"),
                          Item("b_save", show_label=False))

    traits_view = MEAPView(HSplit(Item("files", editor=files_table,
                                       show_label=False),
                                  mean_widgets),
                           resizable=True,
                           win_title="Create Batch Spreadsheet",
                           width=800, height=700,
                           buttons=[OKButton, CancelButton])
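# Headless usage sketch (assumed, not from the original source): point the
# tool at a directory and write the spreadsheet without opening the GUI.
# Assigning input_directory fires _input_directory_changed() and fills
# `files`; calling the button handler directly then writes the spreadsheet.
def write_batch_spreadsheet(input_dir, xls_path):
    tool = BatchFileTool(input_file_extension=".acq",
                         spreadsheet_file=xls_path)
    tool.input_directory = input_dir
    tool._b_save_fired()
    return tool.files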
class BatchGroupRegisterDZDT(HasTraits):
    # Dummy physiodata to get defaults
    physiodata = Instance(PhysioData)

    # Parameters for SRVF registration
    srvf_lambda = DelegatesTo("physiodata")
    srvf_max_karcher_iterations = DelegatesTo("physiodata")
    srvf_update_min = DelegatesTo("physiodata")
    srvf_karcher_mean_subset_size = DelegatesTo("physiodata")
    srvf_use_moving_ensembled = DelegatesTo("physiodata")
    bspline_before_warping = DelegatesTo("physiodata")
    dzdt_num_inputs_to_group_warping = DelegatesTo("physiodata")
    srvf_t_min = DelegatesTo("physiodata")
    srvf_t_max = DelegatesTo("physiodata")
    n_modes = DelegatesTo("physiodata")

    # For saving outputs
    num_cores = Int(1)
    file_suffix = Str()
    overwrite = Bool(False)
    input_directory = Directory()
    output_directory = Directory()
    files = List(Instance(FileToProcess))

    b_run = Button(label="Run")

    def __init__(self, **traits):
        super(BatchGroupRegisterDZDT, self).__init__(**traits)
        self.num_cores = multiprocessing.cpu_count()

    def _physiodata_default(self):
        return PhysioData()

    @on_trait_change("physiodata.srvf_max_karcher_iterations, "
                     "physiodata.srvf_use_moving_ensembled, "
                     "physiodata.srvf_karcher_mean_subset_size, "
                     "physiodata.bspline_before_warping")
    def params_edited(self):
        self._input_directory_changed()

    def _num_cores_changed(self):
        num_cores = multiprocessing.cpu_count()
        if self.num_cores < 1 or self.num_cores > num_cores:
            self.num_cores = multiprocessing.cpu_count()

    def _configure_io(self):
        potential_files = glob(self.input_directory + "/*mea.mat")
        potential_files = [f for f in potential_files
                           if not f.endswith("_aligned.mea.mat")]

        # Check if the output already exists
        def make_output_file(input_file):
            suffix = os.path.split(os.path.abspath(input_file))[1]
            suffix = suffix[:-len(".mea.mat")] + self.file_suffix
            return self.output_directory + "/" + suffix

        self.files = [FileToProcess(input_file=f,
                                    output_file=make_output_file(f))
                      for f in potential_files]

    def _output_directory_changed(self):
        self._configure_io()

    def _file_suffix_changed(self):
        self._configure_io()

    def _input_directory_changed(self):
        self._configure_io()

    def _b_run_fired(self):
        files_to_run = [f for f in self.files if not f.finished]
        logger.info("Processing %d files using %d cpus", len(files_to_run),
                    self.num_cores)
        pool = multiprocessing.Pool(self.num_cores)
        arglist = [(f.input_file, f.output_file, self.srvf_lambda,
                    self.srvf_max_karcher_iterations, self.srvf_update_min,
                    self.srvf_karcher_mean_subset_size,
                    self.srvf_use_moving_ensembled,
                    self.bspline_before_warping,
                    self.dzdt_num_inputs_to_group_warping, self.srvf_t_min,
                    self.srvf_t_max, self.n_modes) for f in files_to_run]
        pool.map(process_physio_file, arglist)

    mean_widgets = VGroup(
        Item("input_directory"),
        Item("output_directory"),
        Item("file_suffix"),
        Item("srvf_use_moving_ensembled", label="Use Moving Ensembled dZ/dt"),
        Item("bspline_before_warping", label="B Spline smoothing"),
        Item("srvf_lambda", label="Lambda Value"),
        Item("dzdt_num_inputs_to_group_warping",
             label="Template uses N beats"),
        Item("srvf_max_karcher_iterations", label="Max Iterations"),
        Item("n_modes"),
        Item("srvf_t_min"),
        Item("srvf_t_max"),
        Item("num_cores"),
        Item("b_run", show_label=False))

    traits_view = MEAPView(HSplit(Item("files", editor=files_table,
                                       show_label=False),
                                  mean_widgets),
                           resizable=True,
                           win_title="Batch Warp dZ/dt",
                           width=800, height=700,
                           buttons=[OKButton, CancelButton])
class SubjectInfo(HasTraits):
    physiodata = Instance(PhysioData)
    subject_age = DelegatesTo("physiodata")
    subject_gender = DelegatesTo("physiodata")
    subject_weight = DelegatesTo("physiodata")
    subject_height_ft = DelegatesTo("physiodata")
    subject_height_in = DelegatesTo("physiodata")
    subject_electrode_distance_front = DelegatesTo("physiodata")
    subject_electrode_distance_back = DelegatesTo("physiodata")
    subject_electrode_distance_right = DelegatesTo("physiodata")
    subject_electrode_distance_left = DelegatesTo("physiodata")
    subject_resp_max = DelegatesTo("physiodata")
    subject_resp_min = DelegatesTo("physiodata")
    subject_in_mri = DelegatesTo("physiodata")
    subject_control_base_impedance = DelegatesTo("physiodata")

    traits_view = MEAPView(
        VGroup(Item("subject_age"),
               Item("subject_gender"),
               Item("subject_height_ft"),
               Item("subject_height_in"),
               Item("subject_weight"),
               Item("subject_electrode_distance_front"),
               Item("subject_electrode_distance_back"),
               Item("subject_electrode_distance_left"),
               Item("subject_electrode_distance_right"),
               Item("subject_resp_max"),
               Item("subject_resp_min"),
               Item("subject_in_mri"),
               Item('subject_control_base_impedance'),
               label="Participant Measurements"),
        buttons=[OKButton],
        resizable=True,
        win_title="Subject Info",
    )
class MEAPPipeline(HasTraits):
    physiodata = Instance(PhysioData)
    file = File
    outfile = File
    mapping_txt = File
    importer = Instance(Importer)
    importer_kwargs = Dict

    b_import = Button(label="Import data")
    b_subject_info = Button(label="Subject Info")
    b_inspect = Button(label="Inspect data")
    b_resp = Button(label="Process Resp")
    b_detect = Button(label="Detect QRS Complexes")
    b_custom_points = Button(label="Label Waveform Points")
    b_moving_ens = Button(label="Compute Moving Ensembles")
    b_register = Button(label="Register dZ/dt")
    b_fmri_tool = Button(label="Process fMRI")
    b_load_experiment = Button(label="Load Design File")
    b_save = Button(label="save .meap file")
    b_clear_mem = Button(label="Clear Memory")
    interactive = Bool(False)
    saved = Bool(False)
    peaks_detected = Bool(False)
    global_ensemble_marked = Bool(False)

    def _b_clear_mem_fired(self):
        self.physiodata = None
        self.global_ensemble_marked = False
        self.peaks_detected = False

    def import_data(self):
        logger.info("Loading %s", self.file)
        for k, v in self.importer_kwargs.iteritems():
            logger.info("Using attr '%s' from excel file", k)
        if not os.path.exists(self.file):
            fail("unable to open file %s" % self.file,
                 interactive=self.interactive)
            return
        elif self.file.endswith(".acq"):
            importer = AcqImporter(path=str(self.file),
                                   mapping_txt=self.mapping_txt,
                                   **self.importer_kwargs)
        elif self.file.endswith(".mea.mat"):
            pd = load_from_disk(self.file)
            self.physiodata = pd
            if self.physiodata.using_hand_marked_point_priors:
                self.global_ensemble_marked = True
            if self.physiodata.peak_indices.size > 0:
                self.peaks_detected = True
            return
        elif self.file.endswith(".mat"):
            importer = MatfileImporter(path=str(self.file),
                                       **self.importer_kwargs)
        else:
            return

        ui = importer.edit_traits(kind="livemodal")
        if not ui.result:
            logger.info("Import cancelled")
            return
        self.physiodata = importer.get_physiodata()

    def _b_import_fired(self):
        self.import_data()

    def _b_resp_fired(self):
        RespirationProcessor(physiodata=self.physiodata).edit_traits(
            kind="livemodal")

    def _b_subject_info_fired(self):
        SubjectInfo(physiodata=self.physiodata).edit_traits(kind="livemodal")

    def _b_custom_points_fired(self):
        ge = GlobalEnsembleAveragedHeartBeat(physiodata=self.physiodata)
        if not self.physiodata.using_hand_marked_point_priors:
            ge.mark_points()
        ge.edit_traits(kind="livemodal")
        self.physiodata.using_hand_marked_point_priors = True
        self.global_ensemble_marked = True

    def _b_inspect_fired(self):
        DataPlot(physiodata=self.physiodata).edit_traits()

    def _b_detect_fired(self):
        detector = PanTomkinsDetector(physiodata=self.physiodata)
        ui = detector.edit_traits(kind="livemodal")
        if ui.result:
            self.peaks_detected = True

    def _b_moving_ens_fired(self):
        MovingEnsembler(physiodata=self.physiodata).edit_traits()

    def _b_register_fired(self):
        GroupRegisterDZDT(physiodata=self.physiodata).edit_traits()

    def _b_save_fired(self):
        print "writing", self.outfile
        if os.path.exists(self.outfile):
            logger.warn("%s already exists", self.outfile)
        self.physiodata.save(self.outfile)
        self.saved = True
        logger.info("saved %s", self.outfile)

    def _b_fmri_tool_fired(self):
        FMRITool(physiodata=self.physiodata).edit_traits()

    traits_view = MEAPView(
        VGroup(
            VGroup(Item("file"), Item("outfile")),
            VGroup(
                spring,
                Item("b_import", show_label=False,
                     enabled_when="file.endswith('.acq') "
                                  "or file.endswith('.mat')"),
                Item("b_inspect", enabled_when="physiodata is not None"),
                Item("b_subject_info", enabled_when="physiodata is not None"),
                Item("b_resp", enabled_when="physiodata is not None"),
                Item("b_detect", enabled_when="physiodata is not None"),
None"), Item("b_custom_points", enabled_when="peaks_detected"), Item("b_moving_ens", enabled_when="global_ensemble_marked"), Item("b_register", enabled_when="peaks_detected"), Item("b_fmri_tool", enabled_when= "physiodata.processed_respiration_data.size > 0"), Item("b_save", enabled_when= "outfile.endswith('.mea') or outfile.endswith('.mea.mat')" ), Item("b_clear_mem"), spring, show_labels=False)))
class PanTomkinsDetector(HasTraits):
    # Holds the data source
    physiodata = Instance(PhysioData)
    ecg_ts = Instance(TimeSeries)

    dirty = Bool(False)

    # For visualization
    plot = Instance(Plot, transient=True)
    plot_data = Instance(ArrayPlotData, transient=True)
    image_plot = Instance(Plot, transient=True)
    image_plot_data = Instance(ArrayPlotData, transient=True)
    image_selection_tool = Instance(ImageRangeSelector)
    image_plot_selected = DelegatesTo("image_selection_tool")
    peak_vis = Instance(ScatterPlot, transient=True)
    scatter = Instance(ScatterPlot, transient=True)
    start_time = Range(low=0, high=100000.0, initial=0.0)
    window_size = Range(low=1.0, high=100000.0, initial=30.0)
    peak_editor = Instance(PeakPickingTool, transient=True)
    show_censor_regions = Bool(True)

    # Parameters for processing the raw data before PT detecting
    bandpass_min = DelegatesTo("physiodata")
    bandpass_max = DelegatesTo("physiodata")
    smoothing_window_len = DelegatesTo("physiodata")
    smoothing_window = DelegatesTo("physiodata")
    pt_adjust = DelegatesTo("physiodata")
    peak_threshold = DelegatesTo("physiodata")
    apply_filter = DelegatesTo("physiodata")
    apply_diff_sq = DelegatesTo("physiodata")
    apply_smooth_ma = DelegatesTo("physiodata")
    peak_window = DelegatesTo("physiodata")
    qrs_source_signal = DelegatesTo("physiodata")
    qrs_power_signal = Property(Array,
                                depends_on=qrs_power_signal_dependencies)

    # Use a secondary signal to limit the search range
    use_secondary_heartbeat = DelegatesTo("physiodata")
    secondary_heartbeat = DelegatesTo("physiodata")
    secondary_heartbeat_abs = DelegatesTo("physiodata")
    secondary_heartbeat_pre_msec = DelegatesTo("physiodata")
    secondary_heartbeat_window = DelegatesTo("physiodata")
    secondary_heartbeat_window_len = DelegatesTo("physiodata")
    secondary_heartbeat_n_likelihood_bins = DelegatesTo("physiodata")
    aux_signal = Property(Array)  # , depends_on=aux_signal_dependencies

    # Secondary ECG signal
    can_use_ecg2 = Bool(False)
    use_ECG2 = DelegatesTo("physiodata")
    ecg2_weight = DelegatesTo("physiodata")

    # The results of the peak search
    peak_indices = DelegatesTo("physiodata")
    peak_times = DelegatesTo("physiodata")
    peak_values = Array
    dne_peak_indices = DelegatesTo("physiodata")
    dne_peak_times = DelegatesTo("physiodata")
    dne_peak_values = Array
    # Holds the timeseries for the signal vs noise thresholds
    thr_times = Array
    thr_vals = Array

    # For pulling out data for aggregating ecg_ts, dzdt_ts and bp beats
    ecg_matrix = DelegatesTo("physiodata")

    # UI elements
    b_detect = Button(label="Detect QRS", transient=True)
    b_shape = Button(label="Shape-Based Tuning", transient=True)

    def __init__(self, **traits):
        super(PanTomkinsDetector, self).__init__(**traits)
        self.ecg_ts = TimeSeries(physiodata=self.physiodata, contains="ecg")
        # Relevant data to be plotted
        self.ecg_time = self.ecg_ts.time
        self.can_use_ecg2 = "ecg2" in self.physiodata.contents
        # Which ECG signal to use?
        if self.qrs_source_signal == "ecg" or not self.can_use_ecg2:
            self.ecg_signal = normalize(self.ecg_ts.data)
            if self.can_use_ecg2:
                self.ecg2_signal = normalize(self.physiodata.ecg2_data)
            else:
                self.ecg2_signal = None
        else:
            self.ecg_signal = normalize(self.physiodata.ecg2_data)
            self.ecg2_signal = normalize(self.physiodata.ecg_data)

        self.thr_times = self.ecg_time[np.array([0, -1])]
        self.censored_intervals = self.physiodata.censored_intervals

        # Graphics containers
        self.aux_window_graphics = []
        self.censored_region_graphics = []

        # If peaks are already available in the mea.mat,
        # they have already been marked
        if len(self.peak_times):
            self.dirty = False
            self.peak_values = self.ecg_signal[self.peak_indices]
        if self.dne_peak_indices is not None and len(self.dne_peak_indices):
            self.dne_peak_values = self.ecg_signal[self.dne_peak_indices]

    def _b_shape_fired(self):
        self.shape_based_tuning()

    def shape_based_tuning(self):
        logger.info("Running shape-based tuning")
        qrs_stack = peak_stack(self.peak_indices,
                               self.qrs_power_signal,
                               pre_msec=300, post_msec=700,
                               sampling_rate=self.physiodata.ecg_sampling_rate)

    def _b_detect_fired(self):
        self.detect()

    def _show_censor_regions_changed(self):
        for crg in self.censored_region_graphics:
            crg.visible = self.show_censor_regions
        self.plot.request_redraw()

    @cached_property
    def _get_qrs_power_signal(self):
        if self.apply_filter:
            filtered_ecg = normalize(
                bandpass(self.ecg_signal, self.bandpass_min,
                         self.bandpass_max, self.ecg_ts.sampling_rate))
        else:
            filtered_ecg = self.ecg_signal

        # Differentiate and square the signal?
        if self.apply_diff_sq:
            # Differentiate the signal and square it
            diff_sq = np.ediff1d(filtered_ecg, to_begin=0)**2
        else:
            diff_sq = filtered_ecg

        # If we're to apply a moving average smoothing
        if self.apply_smooth_ma:
            # MA smoothing
            smooth_ma = smooth(diff_sq,
                               window_len=self.smoothing_window_len,
                               window=self.smoothing_window)
        else:
            smooth_ma = diff_sq

        if not self.use_ECG2:
            # For visualization purposes
            return normalize(smooth_ma)

        logger.info("Using 2nd ECG signal")
        # Use secondary ECG signal and combine QRS power
        if self.apply_filter:
            filtered_ecg2 = normalize(
                bandpass(self.ecg2_signal, self.bandpass_min,
                         self.bandpass_max, self.ecg_ts.sampling_rate))
        else:
            filtered_ecg2 = self.ecg2_signal
        if self.apply_diff_sq:
            diff_sq2 = np.ediff1d(filtered_ecg2, to_begin=0)**2
        else:
            diff_sq2 = filtered_ecg2
        if self.apply_smooth_ma:
            smooth_ma2 = smooth(diff_sq2,
                                window_len=self.smoothing_window_len,
                                window=self.smoothing_window)
        else:
            smooth_ma2 = diff_sq2
        return normalize(((1 - self.ecg2_weight) * smooth_ma +
                          self.ecg2_weight * smooth_ma2)**2)

    def _get_aux_signal(self):
        if not self.use_secondary_heartbeat:
            return np.array([])
        sig = getattr(self.physiodata, self.secondary_heartbeat + "_data")
        if self.secondary_heartbeat_abs:
            sig = np.abs(sig)
        if self.secondary_heartbeat_window_len > 0:
            sig = smooth(sig, window_len=self.secondary_heartbeat_window_len,
                         window=self.secondary_heartbeat_window)
        return normalize(sig)

    @on_trait_change(algorithm_parameters)
    def params_edited(self):
        self.dirty = True

    def update_aux_window_graphics(self):
        if not hasattr(self, "ecg_trace") or not self.use_secondary_heartbeat:
            return
        # Remove the old graphics
        for aux_window in self.aux_window_graphics:
            del self.ecg_trace.index.metadata[aux_window.metadata_name]
            self.ecg_trace.overlays.remove(aux_window)
        # Add the new graphics
        for n, (start, end) in enumerate(self.aux_windows / 1000.):
            window_key = "aux%03d" % n
            self.aux_window_graphics.append(
                AuxSignalWindow(component=self.aux_trace,
                                metadata_name=window_key))
            self.ecg_trace.overlays.append(self.aux_window_graphics[-1])
            self.ecg_trace.index.metadata[window_key] = start, end

    def get_aux_windows(self):
        aux_peaks = find_peaks(self.aux_signal)
        aux_pre_peak = aux_peaks - self.secondary_heartbeat_pre_msec
        aux_windows = np.column_stack([aux_pre_peak, aux_peaks])
        return aux_windows

    def detect(self):
        """
        Implementation of the Pan Tomkins QRS detector
        """
        logger.info("Beginning QRS Detection")
        t0 = time.time()

        # The original paper used a different method for finding peaks
        smoothdiff = normalize(np.ediff1d(self.qrs_power_signal, to_begin=0))
        peaks = find_peaks(smoothdiff)
        peak_amps = self.qrs_power_signal[peaks]

        # Part 2: getting rid of useless peaks
        # ====================================
        # There are lots of small, irrelevant peaks that need to
        # be discarded.

        # 2a) remove peaks occurring in censored intervals
        censor_peak_mask = censor_peak_times(
            # TODO: switch to physiodata.censored_intervals
            self.ecg_ts.censored_regions,
            self.ecg_time[peaks])
        n_removed_peaks = peaks.shape[0] - censor_peak_mask.sum()
        logger.info("%d/%d potential peaks outside %d censored intervals",
                    n_removed_peaks, peaks.shape[0],
                    len(self.ecg_ts.censored_regions))
        peaks = peaks[censor_peak_mask]
        peak_amps = peak_amps[censor_peak_mask]

        # 2b) if a second signal is used, make sure the ecg peaks are
        #     near aux signal peaks
        if self.use_secondary_heartbeat:
            self.aux_windows = self.get_aux_windows()
            aux_mask = times_contained_in(peaks, self.aux_windows)
            logger.info("Using secondary signal")
            logger.info("%d aux peaks detected", self.aux_windows.shape[0])
            logger.info("%d/%d peaks contained in second signal window",
                        aux_mask.sum(), peaks.shape[0])
            peaks = peaks[aux_mask]
            peak_amps = peak_amps[aux_mask]

        # 2c) use Otsu's method to find a cutoff value
        peak_amp_thr = threshold_otsu(peak_amps) + self.pt_adjust
        otsu_mask = peak_amps > peak_amp_thr
        logger.info("otsu threshold: %.7f", peak_amp_thr)
        logger.info("%d/%d peaks survive", otsu_mask.sum(), peaks.shape[0])
        peaks = peaks[otsu_mask]
        peak_amps = peak_amps[otsu_mask]

        # 3) Make sure there is only one peak in each secondary window.
        #    This is accomplished by estimating the distribution of peak
        #    times relative to aux window starts.
        if self.use_secondary_heartbeat:
            npeaks = peaks.shape[0]
            # Obtain a distribution of times and amplitudes
            rel_peak_times = np.zeros_like(peaks)
            window_subsets = []
            for start, end in self.aux_windows:
                mask = np.flatnonzero((peaks >= start) & (peaks <= end))
                if len(mask) == 0:
                    continue
                window_subsets.append(mask)
                rel_peak_times[mask] = peaks[mask] - start
            # How common are peak times relative to window starts?
            densities, bins = np.histogram(
                rel_peak_times,
                range=(0, self.secondary_heartbeat_pre_msec),
                bins=self.secondary_heartbeat_n_likelihood_bins,
                density=True)
            likelihoods = densities[np.clip(
                np.digitize(rel_peak_times, bins) - 1, 0,
                self.secondary_heartbeat_n_likelihood_bins - 1)]
            _peaks = []
            # Pull out the maximal peak in each window
            for subset in window_subsets:
                # If there's only a single peak contained, no need for math
                if len(subset) == 1:
                    _peaks.append(peaks[subset[0]])
                    continue
                _peaks.append(peaks[subset[np.argmax(likelihoods[subset])]])
            peaks = np.array(_peaks)
            logger.info("Only 1 peak per aux window allowed:"
                        " %d/%d peaks remaining", peaks.shape[0], npeaks)

        # Check that no two peaks are too close
        peak_diffs = np.ediff1d(peaks, to_begin=500)
        peaks = peaks[peak_diffs > 200]  # cutoff is 300 BPM

        # Stack the peaks and see if the original data has a higher value
        raw_stack = peak_stack(peaks,
                               self.ecg_signal,
                               pre_msec=self.peak_window,
                               post_msec=self.peak_window,
                               sampling_rate=self.ecg_ts.sampling_rate)
        adj_factors = np.argmax(raw_stack, axis=1) - self.peak_window
        peaks = peaks + adj_factors

        self.peak_indices = peaks
        self.peak_values = self.ecg_signal[peaks]
        self.peak_times = self.ecg_time[peaks]
        self.thr_vals = np.array([peak_amp_thr] * 2)
        t1 = time.time()

        # Update the scatterplot if we're interactive
        if self.plot_data is not None:
            self.plot_data.set_data("peak_times", self.peak_times)
            self.plot_data.set_data("peak_values", self.peak_values)
            self.plot_data.set_data("qrs_power", self.qrs_power_signal)
            self.plot_data.set_data("aux_signal", self.aux_signal)
            self.plot_data.set_data("thr_vals", self.thr_vals)
            self.plot_data.set_data(
                "thr_times", np.array([self.ecg_time[0], self.ecg_time[-1]]))
            self.update_aux_window_graphics()
            self.plot.request_redraw()
            self.image_plot_data.set_data("imagedata", self.ecg_matrix)
            self.image_plot.request_redraw()
        else:
            logger.warn("plot data is None")

        logger.info("found %d QRS complexes in %.3f seconds",
                    len(self.peak_indices), t1 - t0)
        self.dirty = False

    def _plot_default(self):
        """
        Creates a plot of the ecg_ts data and the signals derived
        during the Pan Tomkins algorithm
        """
        # Create plotting components
        self.plot_data = ArrayPlotData(
            time=self.ecg_time,
            raw_ecg=self.ecg_signal,
            qrs_power=self.qrs_power_signal,
            aux_signal=self.aux_signal,
            # Ensembleable peaks
            peak_times=self.peak_times,
            peak_values=self.peak_values,
            # Non-ensembleable, useful for HR, HRV
            dne_peak_times=self.dne_peak_times,
            dne_peak_values=self.dne_peak_values,
            # Threshold slider bar
            thr_times=np.array([self.ecg_time[0], self.ecg_time[-1]]),
            thr_vals=np.array([0, 0]))

        plot = Plot(self.plot_data, use_backbuffer=True)
        self.aux_trace = plot.plot(("time", "aux_signal"),
                                   color="purple", alpha=0.6,
                                   line_width=2)[0]
        self.qrs_power_trace = plot.plot(("time", "qrs_power"),
                                         color="blue", alpha=0.8)[0]
        self.ecg_trace = plot.plot(("time", "raw_ecg"), color="red",
                                   line_width=2)[0]

        # Load the censor regions and add them to the plot
        for n, (start, end) in enumerate(self.censored_intervals):
            censor_key = "censor%03d" % n
            self.censored_region_graphics.append(
                StaticCensoredInterval(component=self.qrs_power_trace,
                                       metadata_name=censor_key))
            self.ecg_trace.overlays.append(self.censored_region_graphics[-1])
            self.ecg_trace.index.metadata[censor_key] = start, end

        # Line showing the QRS power detection threshold
        self.threshold_trace = plot.plot(("thr_times", "thr_vals"),
                                         color="black", line_width=3)[0]

        # Plot for plausible peaks
        self.ppeak_vis = plot.plot(("dne_peak_times", "dne_peak_values"),
                                   type="scatter", marker="diamond")[0]

        # Make a scatter plot where you can edit the peaks
        self.peak_vis = plot.plot(("peak_times", "peak_values"),
                                  type="scatter")[0]
        self.scatter = plot.components[-1]
        self.peak_editor = PeakPickingTool(self.scatter)
        self.scatter.tools.append(self.peak_editor)
        # When the user adds or removes a point, automatically extract
        self.on_trait_event(self._change_peaks,
                            "peak_editor.done_selecting")
        self.scatter.overlays.append(
            PeakPickingOverlay(component=self.scatter))
        return plot

    # TODO: have these update other variables in physiodata:
    #       hand_labeled, mea, etc.
    def _delete_peaks(self, peaks_to_delete):
        pass

    def _add_peaks(self, peaks_to_add):
        pass
    # End TODO

    def _change_peaks(self):
        if self.dirty:
            messagebox("You shouldn't edit peaks until you've run Detect QRS")
        interval = self.peak_editor.selection
        if interval is None:
            return
        mode = self.peak_editor.selection_purpose
logger.info("PeakPickingTool entered %s mode over %s", mode, str(interval)) if mode == "delete": # Do it for real peaks ok_peaks = np.logical_not( np.logical_and(self.peak_times > interval[0], self.peak_times < interval[1])) self.peak_indices = self.peak_indices[ok_peaks] self.peak_values = self.peak_values[ok_peaks] self.peak_times = self.peak_times[ok_peaks] # And do it for ok_dne_peaks = np.logical_not( np.logical_and(self.dne_peak_times > interval[0], self.dne_peak_times < interval[1])) self.dne_peak_indices = self.dne_peak_indices[ok_dne_peaks] self.dne_peak_values = self.dne_peak_values[ok_dne_peaks] self.dne_peak_times = self.dne_peak_times[ok_dne_peaks] else: # Find the signal contained in the selection sig = np.logical_and(self.ecg_time > interval[0], self.ecg_time < interval[1]) sig_inds = np.flatnonzero(sig) selected_sig = self.ecg_signal[sig_inds] # find the peak in the selected region peak_ind = sig_inds[0] + np.argmax(selected_sig) # If we're in add peak mode, always make sure # that only unique peaks get added: if mode == "add": real_peaks = np.sort( np.unique(self.peak_indices.tolist() + [peak_ind])) self.peak_indices = real_peaks self.peak_values = self.ecg_signal[real_peaks] self.peak_times = self.ecg_time[real_peaks] else: real_dne_peaks = np.sort( np.unique(self.dne_peak_indices.tolist() + [peak_ind])) self.dne_peak_indices = real_dne_peaks self.dne_peak_values = self.ecg_signal[real_dne_peaks] self.dne_peak_times = self.ecg_time[real_dne_peaks] # update the scatterplot if we're interactive if not self.plot_data is None: self.plot_data.set_data("peak_times", self.peak_times) self.plot_data.set_data("peak_values", self.peak_values) self.plot_data.set_data("dne_peak_times", self.dne_peak_times) self.plot_data.set_data("dne_peak_values", self.dne_peak_values) self.image_plot_data.set_data("imagedata", self.ecg_matrix) self.image_plot.request_redraw() @on_trait_change("window_size,start_time") def update_plot_range(self): self.plot.index_range.high = self.start_time + self.window_size self.plot.index_range.low = self.start_time @on_trait_change("image_plot_selected") def snap_to_image_selection(self): if not self.peak_times.size > 0: return tmin = self.peak_times[int(self.image_selection_tool.ymin)] - 2. tmax = self.peak_times[int(self.image_selection_tool.ymax)] + 2. logger.info("selection tool sends data to (%.2f, %.2f)", tmin, tmax) self.plot.index_range.low = tmin - 2. self.plot.index_range.high = tmax + 2. 
    def _image_plot_default(self):
        # For the image plot
        img = self.ecg_matrix
        if self.ecg_matrix.size == 0:
            img = np.zeros((100, 100))
        self.image_plot_data = ArrayPlotData(imagedata=img)
        plot = Plot(self.image_plot_data)
        self.image_selection_tool = ImageRangeSelector(component=plot,
                                                       tool_mode="range",
                                                       axis="value",
                                                       always_on=True)
        plot.img_plot("imagedata", colormap=jet, name="plot1",
                      origin="bottom left")[0]
        plot.overlays.append(self.image_selection_tool)
        return plot

    detection_params_group = VGroup(
        HGroup(
            Group(
                VGroup(Item("use_secondary_heartbeat"),
                       Item("peak_window"),
                       Item("apply_filter"),
                       Item("bandpass_min"),
                       Item("bandpass_max"),
                       Item("smoothing_window_len"),
                       Item("smoothing_window"),
                       Item("pt_adjust"),
                       Item("apply_diff_sq"),
                       label="Pan Tomkins"),
                VGroup(  # Items for MultiSignal detection
                    Item("secondary_heartbeat"),
                    Item("secondary_heartbeat_pre_msec"),
                    Item("secondary_heartbeat_abs"),
                    Item("secondary_heartbeat_window"),
                    Item("secondary_heartbeat_window_len"),
                    label="MultiSignal Detection",
                    enabled_when="use_secondary_heartbeat"),
                VGroup(  # Items for two ECG signals
                    Item("use_ECG2"),
                    Item("ecg2_weight"),
                    label="Multiple ECG",
                    enabled_when="can_use_ecg2"),
                label="Beat Detection Options",
                layout="tabbed",
                show_border=True,
                springy=True),
            Item("image_plot", editor=ComponentEditor(), width=300),
            show_labels=False),
        Item("b_detect", enabled_when="dirty"),
        Item("b_shape"),
        show_labels=False)

    plot_group = Group(
        Group(Item("plot", editor=ComponentEditor(), width=800),
              show_labels=False),
        Item("start_time"),
        Item("window_size"))

    traits_view = MEAPView(
        VSplit(plot_group, detection_params_group),
        resizable=True,
        buttons=[OKButton, CancelButton],
        title="Detect Heart Beats")
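# ---------------------------------------------------------------------------
# The transform computed by _get_qrs_power_signal and thresholded in
# detect() is easier to see outside the Traits machinery. The sketch below
# is an illustration, NOT MEAP's implementation: it assumes plain
# NumPy/SciPy plus scikit-image, and qrs_power_sketch /
# candidate_peaks_sketch are hypothetical names standing in for MEAP's
# bandpass(), smooth(), normalize() and find_peaks() helpers.
from scipy.signal import butter, filtfilt
from skimage.filters import threshold_otsu as _otsu


def qrs_power_sketch(ecg, fs, f_lo=5.0, f_hi=15.0, win_sec=0.1):
    """Return a [0, 1]-normalized QRS 'power' signal from a raw ECG trace."""
    # 1) Bandpass to keep the frequency band where QRS energy lives
    b, a = butter(3, [f_lo / (fs / 2.0), f_hi / (fs / 2.0)], btype="band")
    filtered = filtfilt(b, a, ecg)
    # 2) Differentiate and square to emphasize steep QRS slopes
    diff_sq = np.ediff1d(filtered, to_begin=0) ** 2
    # 3) Moving-average smoothing over roughly 100 msec
    n = max(int(win_sec * fs), 1)
    smoothed = np.convolve(diff_sq, np.ones(n) / n, mode="same")
    return (smoothed - smoothed.min()) / np.ptp(smoothed)


def candidate_peaks_sketch(power, pt_adjust=0.0):
    """Keep local maxima whose amplitude survives the Otsu cutoff,
    mirroring step 2c of detect()."""
    maxima = np.flatnonzero((power[1:-1] > power[:-2]) &
                            (power[1:-1] >= power[2:])) + 1
    thr = _otsu(power[maxima]) + pt_adjust
    return maxima[power[maxima] > thr]
# ---------------------------------------------------------------------------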
class MEAPTimeseries(HasTraits):
    physiodata = Instance(PhysioData)
    plot = Instance(Plot, transient=True)
    plotdata = Instance(ArrayPlotData, transient=True)
    selection = Instance(BeatSelection, transient=True)
    signal = Str
    selected_beats = Array
    index_datasource = Instance(DataRange1D)
    marker_size = Array
    metadata_name = Str
    selected = Event
    selected_range = Tuple

    traits_view = MEAPView(
        Group(Item('plot', editor=ComponentEditor(size=size),
                   show_label=False),
              orientation="vertical"),
        resizable=True, win_title=title)

    def __init__(self, **traits):
        super(MEAPTimeseries, self).__init__(**traits)
        self.metadata_name = self.signal + "_selection"

    def _selection_changed(self):
        selection = self.selection.selection
        logger.info("%s selection changed to %s", self.signal,
                    str(selection))
        if selection is None or len(selection) == 0:
            return
        self.selected_range = selection
        self.selected = True

    def _plot_default(self):
        plot = Plot(self.plotdata)
        if self.signal in ("tpr", "co", "sv"):
            rc_signal = "resp_corrected_" + self.signal
            if rc_signal in self.plotdata.arrays and \
                    self.plotdata.arrays[rc_signal].size > 0:
                # Plot the resp-uncorrected version underneath
                plot.plot(("peak_times", self.signal), type="line",
                          color="purple")
                signal = rc_signal
            else:
                signal = self.signal
        elif self.signal == "hr":
            plot.plot(("peak_times", self.signal), type="line",
                      color="purple")
            signal = "mea_hr"
        else:
            signal = self.signal

        # Create the plot
        plot.plot(("peak_times", signal), type="line", color="blue")
        plot.plot(("peak_times", signal, "beat_type"),
                  type="cmap_scatter",
                  color_mapper=jet,
                  name="my_plot",
                  marker="circle",
                  border_visible=False,
                  outline_color="transparent",
                  bg_color="white",
                  index_sort="ascending",
                  marker_size=3,
                  fill_alpha=0.8)

        # Tweak some of the plot properties
        plot.title = self.signal
        plot.title_position = "right"
        plot.title_angle = 270
        plot.line_width = 1
        plot.padding = 20

        # Some of the tools are a little invasive, so we need the
        # actual ScatterPlot object to give to them
        my_plot = plot.plots["my_plot"][0]

        # Attach some tools to the plot
        self.selection = BeatSelection(component=my_plot)
        my_plot.active_tool = self.selection
        selection_overlay = RangeSelectionOverlay(component=my_plot)
        my_plot.overlays.append(selection_overlay)

        # Set up the trait handler for the selection
        self.selection.on_trait_change(self._selection_changed,
                                       'selection_completed')
        return plot
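# A hedged usage sketch (not part of MEAP): once _selection_changed has
# fired, selected_range holds the (low, high) time bounds, which can be
# mapped back to beat indices. It assumes the instance was built with a
# plotdata containing a "peak_times" array; the function name is
# hypothetical.
def beats_in_selection(mts):
    """Return indices of the beats inside the current range selection."""
    lo, hi = mts.selected_range
    peak_times = mts.plotdata.arrays["peak_times"]
    return np.flatnonzero((peak_times >= lo) & (peak_times <= hi))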
class FMRITool(HasTraits):
    physiodata = Instance(PhysioData)
    processed_respiration_data = DelegatesTo("physiodata")
    processed_respiration_time = DelegatesTo("physiodata")
    slice_times_txt = Str
    slice_times = Array
    slice_times_matrix = Array
    acquisition_type = Enum("Coronal", "Sagittal", "Axial")
    dicom_acquisition_direction = Enum("LPS+", "RAS+")
    unsure_about_direction = Bool(False)
    correction_type = Enum(values="possible_corrections")
    possible_corrections = List(["Slicewise Regression"])
    respiration_expansion_order = Int(4)
    drift_model_order = Int(6)
    add_back_intensity = Bool(True)
    cardiac_expansion_order = Int(3)
    interaction_expansion_order = Int(1)
    physio_design_file = File

    # Local arrays that have been restricted to during-scanning times
    r_times = Array      # beat times occurring during scanning
    tr_onsets = Array    # TR times
    resp_times = Array
    resp_signal = Array

    # NIfTI I/O
    fmri_file = File
    fmri_mask = File
    denoised_output = File
    regression_output = File
    nuisance_output_prefix = Str
    b_calculate = Button(label="RUN")
    write_4d_nuisance = Bool(False)
    write_4d_nuisance_for_each_modality = Bool(False)
    output_units = Enum("Percent Signal Change", "Raw BOLD")
    fmri_tr = Float
    interactive = Bool(False)
    missing_triggers = Int(0)
    missing_volumes = Int(0)

    traits_view = MEAPView(
        VGroup(
            VGroup(Item("fmri_file"),
                   Item("fmri_mask"),
                   Item("slice_times_txt"),
                   Item("acquisition_type"),
                   Item("dicom_acquisition_direction"),
                   Item("unsure_about_direction"),
                   Item("denoised_output"),
                   Item("regression_output"),
                   Item("nuisance_output_prefix"),
                   Item("write_4d_nuisance"),
                   Item("write_4d_nuisance_for_each_modality"),
                   show_border=True,
                   label="fMRI I/O"),
            VGroup(Item("cardiac_expansion_order"),
                   Item("respiration_expansion_order"),
                   Item("interaction_expansion_order"),
                   Item("correction_type"),
                   Item("drift_model_order"),
                   Item("add_back_intensity"),
                   Item("output_units"),
                   Item("physio_design_file"),
                   show_border=True,
                   label="RETROICOR")),
        Item("b_calculate", show_label=False),
        win_title="RETROICOR")

    def _load_slice_timings(self):
        try:
            _slices = np.loadtxt(self.slice_times_txt)
        except Exception, e:
            messagebox("Unable to load slice timings:\n%s" % e)
            raise ValueError("Unable to load slice timings")

        # Check if it's in the mid-TR format
        if np.any(_slices < 0):
            logger.info("Using the mid-TR slice convention")
            if self.fmri_tr <= 0:
                raise ValueError(
                    "fMRI file has to be set before using mid-TR")
            _slices = (_slices + 0.5) * self.fmri_tr

        # Check if the times are in msec or sec
        if np.any(_slices > 10):
            logger.info("Converting slice times from msec to sec")
            _slices = _slices / 1000.
        else:
            logger.info("Assuming slice timings are in seconds")

        # Slice times should be one row per TR, and one column per slice.
        self.slice_times = _slices
        logger.info("Using slice timings of %s", str(self.slice_times))
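# A worked example of the two slice-timing conventions handled by
# _load_slice_timings (the numbers are illustrative and the helper name
# is hypothetical):
#
#   mid-TR fractions, TR = 2.0 s:  [-0.5, -0.25, 0.0, 0.25]
#       (x + 0.5) * TR          -> [ 0.0,  0.5,  1.0,  1.5]  seconds
#   milliseconds:                  [0, 500, 1000, 1500]
#       x / 1000                -> [0.0, 0.5, 1.0, 1.5]      seconds
def normalize_slice_times_sketch(raw, tr):
    """Mirror of the unit handling above, as a standalone function."""
    raw = np.asarray(raw, dtype=float)
    if np.any(raw < 0):    # mid-TR convention
        if tr <= 0:
            raise ValueError("TR must be known for mid-TR timings")
        return (raw + 0.5) * tr
    if np.any(raw > 10):   # values this large can only be msec
        return raw / 1000.
    return raw             # already in seconds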
class TimeSeries(HasTraits):
    # Holds data
    name = Str
    contains = Enum(SUPPORTED_SIGNALS)
    plot = Instance(Plot, transient=True)
    data = Array
    time = Property(Array,
                    depends_on=["start_time", "sampling_rate", "data"])

    # Plotting options
    visible = Bool(True)
    ymax = Float()
    ymin = Float()
    line_type = marker_trait
    line_color = ColorTrait("blue")
    plot_type = Enum("line", "scatter")
    censored_regions = List(Instance(CensorRegion))
    b_add_censor = Button(label="Add CensorRegion", transient=True)
    b_zoom_y = Button(label="Zoom y", transient=True)
    b_info = Button(label="Info", transient=True)
    b_clear_censoring = Button(label="Clear Censoring", transient=True)
    renderer = Instance(LinePlot, transient=True)

    # For the winsorizing steps
    winsor_swap = Array
    winsor_min = Float(0.005)
    winsor_max = Float(0.005)
    winsorize = Bool(False)
    winsor_enable = Bool(True)

    def __init__(self, **traits):
        """Class to represent data collected over time."""
        super(TimeSeries, self).__init__(**traits)
        if self.contains not in self.physiodata.contents:
            raise ValueError("Signal not found in data")
        self.name = self.contains

        # First, check whether it's already winsorized
        winsorize_trait = self.name + "_winsorize"
        if getattr(self.physiodata, winsorize_trait):
            self.winsor_enable = False
            self.winsor_min = getattr(self.physiodata,
                                      self.name + "_winsor_min")
            self.winsor_max = getattr(self.physiodata,
                                      self.name + "_winsor_max")

        # Load the actual data
        self.data = getattr(self.physiodata, self.contains + "_data")
        self.winsor_swap = self.data.copy()
        self.sampling_rate = getattr(self.physiodata,
                                     self.contains + "_sampling_rate")
        self.sampling_rate_unit = getattr(
            self.physiodata, self.contains + "_sampling_rate_unit")
        self.start_time = getattr(self.physiodata,
                                  self.contains + "_start_time")

        # The censored regions are loaded from physiodata initially;
        # from that point on they are managed as CensorRegion objects
        # on this TimeSeries.
        self.line_color = colors[self.contains]
        self.n_censor_intervals = 0
        for (start, end), source in zip(self.physiodata.censored_intervals,
                                        self.physiodata.censoring_sources):
            if str(source) == self.contains:
                self.censored_regions.append(
                    CensorRegion(start_time=start,
                                 end_time=end,
                                 metadata_name=self.__get_metadata_name()))

    def __get_metadata_name(self):
        name = self.contains + "%03d" % self.n_censor_intervals
        self.n_censor_intervals += 1
        return name

    def _winsorize_changed(self):
        if self.winsorize:
            logger.info("Winsorizing %s with limits=(%.5f, %.5f)",
                        self.name, self.winsor_min, self.winsor_max)
            # Take the original data and replace it with the
            # winsorized version
            self.data = np.array(
                winsorize(self.winsor_swap,
                          limits=(self.winsor_min, self.winsor_max)))
            setattr(self.physiodata, self.name + "_winsor_min",
                    self.winsor_min)
            setattr(self.physiodata, self.name + "_winsor_max",
                    self.winsor_max)
        else:
            logger.info("Restoring %s to its original data", self.name)
            self.data = self.winsor_swap.copy()
        setattr(self.physiodata, self.contains + "_data", self.data)
        setattr(self.physiodata, self.name + "_winsorize", self.winsorize)
        self.plot.range2d.y_range.low = self.data.min()
        self.plot.range2d.y_range.high = self.data.max()
        self.plot.request_redraw()

    def __str__(self):
        descr = "Timeseries: %s\n" % self.name
        descr += "-" * len(descr) + "\n\t" + \
            "\n\t".join([
                "Sampling rate: %.4f%s" % (self.sampling_rate,
                                           self.sampling_rate_unit),
                "N samples: %d" % self.data.shape[0],
                "N censored intervals: %d" % len(self.censored_regions),
                "Start time: %.3f" % self.start_time
            ])
        return descr

    def _plot_default(self):
        # Create plotting components
        plotdata = ArrayPlotData(time=self.time, data=self.data)
        plot = Plot(plotdata)
        self.renderer = plot.plot(("time", "data"),
                                  color=self.line_color)[0]
        plot.title = self.name
        plot.title_position = "right"
        plot.title_angle = 270
        plot.line_width = 1
        plot.padding = 25
        plot.width = 400

        # Load the censor regions and add them to the plot
        for censor_region in self.censored_regions:
            # Attach the renderer, then touch .viz so _viz_default
            # fires and installs the selection overlay
            censor_region.plot = self.renderer
            censor_region.viz
            censor_region.set_limits(censor_region.start_time,
                                     censor_region.end_time)
        return plot

    def _get_time(self):
        # Ensure the properties that depend on time are updated
        return np.arange(len(self.data)) / self.sampling_rate \
            + self.start_time

    def _b_info_fired(self):
        messagebox(str(self))

    def _b_zoom_y_fired(self):
        ui = self.edit_traits(view="zoom_view", kind="modal")
        if ui.result:
            self.plot.range2d.y_range.low = self.ymin
            self.plot.range2d.y_range.high = self.ymax

    def _b_clear_censoring_fired(self):
        for reg in self.censored_regions:
            self.renderer.overlays.remove(reg.viz)
        self.censored_regions = []
        self.plot.request_redraw()
        self.n_censor_intervals = 0

    def _b_add_censor_fired(self):
        self.censored_regions.append(
            CensorRegion(plot=self.renderer,
                         metadata_name=self.__get_metadata_name()))
        self.censored_regions[-1].viz

    buttons = VGroup(
        Item("b_add_censor", show_label=False),
        # Item("b_info"),
        # Item("b_zoom_y"),
        Item("winsorize", enabled_when="winsor_enable"),
        Item("winsor_max", enabled_when="winsor_enable",
             format_str="%.4f", width=25),
        Item("winsor_min", enabled_when="winsor_enable",
             format_str="%.4f", width=25),
        Item("b_clear_censoring", show_label=False),
        show_labels=True,
        show_border=True)

    widgets = HSplit(
        Item('plot', editor=ComponentEditor(), show_label=False),
        buttons)

    zoom_view = MEAPView(HGroup("ymin", "ymax"),
                         buttons=[OKButton, CancelButton])

    traits_view = MEAPView(widgets, width=500, height=300, resizable=True)
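# A small demonstration of the winsorizing step TimeSeries applies,
# assuming SciPy's winsorize (MEAP's winsorize is called with the same
# limits=(lower, upper) fraction signature). The data values are made up.
from scipy.stats.mstats import winsorize as _winsorize

_raw = np.array([0.1, 0.2, 0.2, 0.3, 9.9])  # one artifactual spike
_clean = np.array(_winsorize(_raw, limits=(0.0, 0.2)))
# With an upper limit of 0.2, the largest 20% of samples (here the 9.9
# spike) is replaced by the next-largest value, 0.3, leaving
# [0.1, 0.2, 0.2, 0.3, 0.3].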