# Example #1
0
class MEAPhysioExperiment(PhysioExperiment):
    """Experiment that loads its spreadsheet and runs MEA extraction on each
    physio file it lists.
    """
    # Experimental events parsed for this experiment
    events = List(Instance(ExpEvent))
    # One MEA-analysis wrapper per physio file
    physio_files = List(Instance(PhysioFileMEAAnalysis))

    def _b_save_fired(self):
        # The "save" button is not supported for this experiment type.
        raise NotImplementedError()

    def file_constructur(self):
        """Return the class used to wrap each physio file.

        NOTE(review): the method name is misspelled ("constructur"); it may be
        looked up by this exact name elsewhere, so check before renaming.
        """
        return PhysioFileMEAAnalysis

    def event_constructor(self):
        """Return the class used to build event objects."""
        return MovingEnsembler

    def run(self):
        """Read the spreadsheet, process its events if needed, then extract MEA."""
        self.open_xls()
        needs_event_processing = not self.events_processed
        if needs_event_processing:
            self._process_events()
        # Run the pipeline on each acquisition file
        self.extract_mea()

    def extract_mea(self):
        """Ask every physio file to run its own MEA extraction."""
        # NOTE(review): this per-condition dict is built but never used.
        stacks = {condition: {} for condition in self.conditions}
        for physio_file in self.physio_files:
            physio_file.extract_mea()

    def save_mea_document(self, output_path=""):
        """Resolve the output path; writing the document is not implemented.

        An empty ``output_path`` falls back to ``self.output_xls``.
        """
        # TODO: ensemble averaging (self.ensemble_average) was disabled here.
        output_path = output_path or self.output_xls

    traits_view = MEAPView(
        HSplit(
            Group(
                Group(
                    Item("input_path", label="Input XLS"),
                    Item("output_directory", label="Output Directory"),
                    Item("b_run", show_label=False),
                    Item("b_save", show_label=False),
                    orientation="horizontal",
                ),
            ),
        ),
        resizable=True,
        win_title="MEA Analysis",
    )
# Example #2
0
class PhysioFilePreprocessor(HasTraits):
    """Wraps one physio file together with the MEAP pipeline that processes it."""
    # Loaded physiological data for this file
    physiodata = Instance(PhysioData)
    interactive = Bool(True)
    pipeline = Instance(MEAPPipeline)
    file = DelegatesTo("pipeline")
    outfile = DelegatesTo("pipeline")
    # Per-file metadata; "subject_*" matches are forwarded in __init__
    specs = Dict
    importer_kwargs = DelegatesTo("pipeline")
    # Whether a saved MEA output appears to exist for this file
    mea_saved = Property(Bool, depends_on="file,outfile")

    def _outfile_default(self):
        """Derive the output path: swap a trailing ".acq" for ".mea", else append ".mea"."""
        stem = self.file[:-4] if self.file.endswith(".acq") else self.file
        return stem + ".mea"

    def __init__(self, **traits):
        super(PhysioFilePreprocessor, self).__init__(**traits)

        # Specs from the spreadsheet whose "subject_<name>" form is an
        # editable PhysioData trait are passed on to the importer so they are
        # attached to the physiodata once it is loaded.
        physio_trait_names = PhysioData().editable_traits()
        for spec_name, spec_value in self.specs.items():
            subject_key = "subject_" + spec_name
            if subject_key in physio_trait_names:
                self.importer_kwargs[subject_key] = spec_value

    def _pipeline_default(self):
        return MEAPPipeline()

    def _get_mea_saved(self):
        """True if the input name ends in "mea", or the pipeline's outfile
        exists on disk (as-is or with a ".mat" suffix)."""
        if self.file.endswith("mea"):
            return True
        target = self.pipeline.outfile
        return any(os.path.exists(path) for path in (target, target + ".mat"))

    @on_trait_change("pipeline.saved")
    def pipeline_saved(self):
        """A dumb hack: briefly change ``outfile`` and restore it.

        ``mea_saved`` depends on ``outfile``, so the round trip makes its
        listeners see a change after the pipeline saves.
        """
        previous_outfile = self.outfile
        self.outfile = "asdf"
        self.outfile = previous_outfile

    pipeline_view = MEAPView(
        Group(
            Item("pipeline", style="custom"),
            show_labels=False,
        )
    )
# Example #3
0
 def default_traits_view(self):
     """Build the default view: one plot per signal plus navigation widgets.

     Every signal in ``physiodata.contents`` whose name does not start with
     "resp_corr" gets a "<signal>_ts" plot item, in sorted order. When an
     "ecg2" signal is present, a "qrs_source_signal" selector (labelled
     "ECG to use") is shown before the start-time / window-size controls.
     """
     plots = [Item(sig + "_ts", style="custom", show_label=False)
              for sig in sorted(self.physiodata.contents)
              if not sig.startswith("resp_corr")]
     widget_items = []
     if "ecg2" in self.physiodata.contents:
         widget_items.append(Item("qrs_source_signal", label="ECG to use"))
     widget_items.extend([Item("start_time"), Item("window_size")])
     widgets = VGroup(*widget_items)
     return MEAPView(VGroup(VGroup(*plots),
                            Group(widgets, orientation="horizontal")),
                     width=1000,
                     height=600,
                     resizable=True,
                     # Title previously misspelled as "Aqcknowledge"
                     win_title="AcqKnowledge Data",
                     buttons=[OKButton, CancelButton],
                     handler=DataPlotHandler())
# Example #4
0
class PanTomkinsDetector(HasTraits):
    """Interactive Pan-Tompkins-style QRS (heartbeat) detector.

    A "QRS power" signal is derived from the ECG. The steps are: optional
    bandpass filter, optional differentiate-and-square, optional
    moving-average smoothing, and optional blending with a second ECG.
    Candidate peaks are found on the derivative of that signal. They are
    pruned by censored intervals, optional secondary-signal windows, an Otsu
    threshold and a minimum spacing, then snapped to the raw-ECG maximum.
    Detection parameters and results are delegated to ``physiodata``. A
    Chaco plot lets the user add or delete peaks by hand.
    """
    # Holds the data Source
    physiodata = Instance(PhysioData)
    ecg_ts = Instance(TimeSeries)
    # Set when a detection parameter changes; cleared at the end of detect()
    dirty = Bool(False)

    # For visualization
    plot = Instance(Plot, transient=True)
    plot_data = Instance(ArrayPlotData, transient=True)
    image_plot = Instance(Plot, transient=True)
    image_plot_data = Instance(ArrayPlotData, transient=True)
    image_selection_tool = Instance(ImageRangeSelector)
    image_plot_selected = DelegatesTo("image_selection_tool")
    peak_vis = Instance(ScatterPlot, transient=True)
    scatter = Instance(ScatterPlot, transient=True)
    start_time = Range(low=0, high=100000.0, initial=0.0)
    window_size = Range(low=1.0, high=100000.0, initial=30.0)
    peak_editor = Instance(PeakPickingTool, transient=True)
    show_censor_regions = Bool(True)

    # parameters for processing the raw data before PT detecting
    bandpass_min = DelegatesTo("physiodata")
    bandpass_max = DelegatesTo("physiodata")
    smoothing_window_len = DelegatesTo("physiodata")
    smoothing_window = DelegatesTo("physiodata")
    pt_adjust = DelegatesTo("physiodata")
    peak_threshold = DelegatesTo("physiodata")
    apply_filter = DelegatesTo("physiodata")
    apply_diff_sq = DelegatesTo("physiodata")
    apply_smooth_ma = DelegatesTo("physiodata")
    peak_window = DelegatesTo("physiodata")
    qrs_source_signal = DelegatesTo("physiodata")
    qrs_power_signal = Property(Array,
                                depends_on=qrs_power_signal_dependencies)

    # Use a secondary signal to limit the search range
    use_secondary_heartbeat = DelegatesTo("physiodata")
    secondary_heartbeat = DelegatesTo("physiodata")
    secondary_heartbeat_abs = DelegatesTo("physiodata")
    secondary_heartbeat_pre_msec = DelegatesTo("physiodata")
    secondary_heartbeat_window = DelegatesTo("physiodata")
    secondary_heartbeat_window_len = DelegatesTo("physiodata")
    secondary_heartbeat_n_likelihood_bins = DelegatesTo("physiodata")
    # Not cached: recomputed on every access
    aux_signal = Property(Array)  #, depends_on=aux_signal_dependencies)

    # Secondary ECG signal
    can_use_ecg2 = Bool(False)
    use_ECG2 = DelegatesTo("physiodata")
    ecg2_weight = DelegatesTo("physiodata")

    # The results of the peak search
    # Ensembleable peaks
    peak_indices = DelegatesTo("physiodata")
    peak_times = DelegatesTo("physiodata")
    peak_values = Array
    # Non-ensembleable ("dne") peaks, useful for HR, HRV
    dne_peak_indices = DelegatesTo("physiodata")
    dne_peak_times = DelegatesTo("physiodata")
    dne_peak_values = Array
    # Holds the timeseries for the signal vs noise thresholds
    thr_times = Array
    thr_vals = Array

    # for pulling out data for aggregating ecg_ts, dzdt_ts and bp beats
    ecg_matrix = DelegatesTo("physiodata")

    # UI elements
    b_detect = Button(label="Detect QRS", transient=True)
    b_shape = Button(label="Shape-Based Tuning", transient=True)

    def __init__(self, **traits):
        """Load the ECG signal(s) from ``physiodata`` and restore saved peaks.

        When ``qrs_source_signal`` is not "ecg" and an "ecg2" channel exists,
        ecg2 becomes the primary signal and ecg the secondary one.
        """
        super(PanTomkinsDetector, self).__init__(**traits)
        self.ecg_ts = TimeSeries(physiodata=self.physiodata, contains="ecg")
        # relevant data to be plotted
        self.ecg_time = self.ecg_ts.time

        self.can_use_ecg2 = "ecg2" in self.physiodata.contents

        # Which ECG signal to use?
        if self.qrs_source_signal == "ecg" or not self.can_use_ecg2:
            self.ecg_signal = normalize(self.ecg_ts.data)
            if self.can_use_ecg2:
                self.ecg2_signal = normalize(self.physiodata.ecg2_data)
            else:
                self.ecg2_signal = None
        else:
            self.ecg_signal = normalize(self.physiodata.ecg2_data)
            self.ecg2_signal = normalize(self.physiodata.ecg_data)

        self.thr_times = self.ecg_time[np.array([0, -1])]
        self.censored_intervals = self.physiodata.censored_intervals
        # graphics containers
        self.aux_window_graphics = []
        self.censored_region_graphics = []
        # If peaks are already available in the mea.mat,
        # they have already been marked
        if len(self.peak_times):
            self.dirty = False
            self.peak_values = self.ecg_signal[self.peak_indices]
        if self.dne_peak_indices is not None and len(self.dne_peak_indices):
            self.dne_peak_values = self.ecg_signal[self.dne_peak_indices]

    def _b_shape_fired(self):
        self.shape_based_tuning()

    def shape_based_tuning(self):
        """Stack the QRS power signal around the current peaks.

        NOTE(review): the stack is computed but not used yet; this looks like
        unfinished work.
        """
        logger.info("Running shape-based tuning")
        qrs_stack = peak_stack(self.peak_indices,
                               self.qrs_power_signal,
                               pre_msec=300,
                               post_msec=700,
                               sampling_rate=self.physiodata.ecg_sampling_rate)

    def _b_detect_fired(self):
        self.detect()

    def _show_censor_regions_changed(self):
        for crg in self.censored_region_graphics:
            crg.visible = self.show_censor_regions
        self.plot.request_redraw()

    def _qrs_power_component(self, signal):
        """Apply the enabled filter / diff-square / smoothing steps to one ECG signal."""
        if self.apply_filter:
            signal = normalize(
                bandpass(signal, self.bandpass_min, self.bandpass_max,
                         self.ecg_ts.sampling_rate))
        # Differentiate the signal and square it
        if self.apply_diff_sq:
            signal = np.ediff1d(signal, to_begin=0)**2
        # Moving-average smoothing
        if self.apply_smooth_ma:
            signal = smooth(signal,
                            window_len=self.smoothing_window_len,
                            window=self.smoothing_window)
        return signal

    @cached_property
    def _get_qrs_power_signal(self):
        """Normalized QRS power, optionally blended with the second ECG.

        With ``use_ECG2`` the result is
        normalize(((1 - w) * p1 + w * p2) ** 2), where w = ``ecg2_weight``.
        """
        smooth_ma = self._qrs_power_component(self.ecg_signal)
        if not self.use_ECG2:
            # for visualization purposes
            return normalize(smooth_ma)
        if self.ecg2_signal is None:
            # use_ECG2 can be stored on physiodata even without an ecg2 channel
            logger.warning("use_ECG2 is set but no second ECG is available; "
                           "using the primary ECG only")
            return normalize(smooth_ma)
        logger.info("Using 2nd ECG signal")

        # Use secondary ECG signal and combine QRS power
        smooth_ma2 = self._qrs_power_component(self.ecg2_signal)
        return normalize(((1 - self.ecg2_weight) * smooth_ma +
                          self.ecg2_weight * smooth_ma2)**2)

    def _get_aux_signal(self):
        """Normalized secondary heartbeat signal (optionally rectified and
        smoothed), or an empty array when it is not in use."""
        if not self.use_secondary_heartbeat: return np.array([])
        sig = getattr(self.physiodata, self.secondary_heartbeat + "_data")
        if self.secondary_heartbeat_abs:
            sig = np.abs(sig)
        if self.secondary_heartbeat_window_len > 0:
            sig = smooth(sig,
                         window_len=self.secondary_heartbeat_window_len,
                         window=self.secondary_heartbeat_window)
        return normalize(sig)

    @on_trait_change(algorithm_parameters)
    def params_edited(self):
        self.dirty = True

    def update_aux_window_graphics(self):
        """Redraw the secondary-signal window overlays on the ECG trace.

        Overlays added by the previous call are always removed and forgotten,
        so they are never removed twice. New ones are added from
        ``self.aux_windows`` only while the secondary heartbeat is in use.
        """
        if not hasattr(self, "ecg_trace"):
            return
        # remove the old graphics
        for aux_window in self.aux_window_graphics:
            del self.ecg_trace.index.metadata[aux_window.metadata_name]
            self.ecg_trace.overlays.remove(aux_window)
        self.aux_window_graphics = []
        if not self.use_secondary_heartbeat:
            return
        # add the new graphics
        # NOTE(review): aux_windows hold sample indices; dividing by 1000
        # gives seconds only at a 1000 Hz sampling rate -- confirm.
        for n, (start, end) in enumerate(self.aux_windows / 1000.):
            window_key = "aux%03d" % n
            self.aux_window_graphics.append(
                AuxSignalWindow(component=self.aux_trace,
                                metadata_name=window_key))
            self.ecg_trace.overlays.append(self.aux_window_graphics[-1])
            self.ecg_trace.index.metadata[window_key] = start, end

    def get_aux_windows(self):
        """Return an (n, 2) array of [peak - pre_msec, peak] windows around
        each peak of the aux signal."""
        aux_peaks = find_peaks(self.aux_signal)
        # NOTE(review): subtracts milliseconds from sample indices, which is
        # only consistent at 1000 Hz -- confirm.
        aux_pre_peak = aux_peaks - self.secondary_heartbeat_pre_msec
        aux_windows = np.column_stack([aux_pre_peak, aux_peaks])
        return aux_windows

    def detect(self):
        """
        Implementation of the Pan Tomkins QRS detector.

        Results are stored in peak_indices / peak_values / peak_times. They
        are pushed to the plots when a plot has been created. ``dirty`` is
        cleared at the end.
        """
        logger.info("Beginning QRS Detection")
        t0 = time.time()

        # The original paper used a different method for finding peaks
        smoothdiff = normalize(np.ediff1d(self.qrs_power_signal, to_begin=0))
        peaks = find_peaks(smoothdiff)
        peak_amps = self.qrs_power_signal[peaks]

        # Part 2: getting rid of useless peaks
        # ====================================
        # There are lots of small, irrelevant peaks that need to
        # be discarded.
        # 2a) remove peaks occurring in censored intervals
        censor_peak_mask = censor_peak_times(
            self.ecg_ts.censored_regions,  #TODO: switch to physiodata.censored_intervals
            self.ecg_time[peaks])
        n_removed_peaks = peaks.shape[0] - censor_peak_mask.sum()
        logger.info("%d/%d potential peaks outside %d censored intervals",
                    n_removed_peaks, peaks.shape[0],
                    len(self.ecg_ts.censored_regions))
        peaks = peaks[censor_peak_mask]
        peak_amps = peak_amps[censor_peak_mask]

        # 2b) if a second signal is used, make sure the ecg peaks are
        #     near aux signal peaks
        if self.use_secondary_heartbeat:
            self.aux_windows = self.get_aux_windows()
            aux_mask = times_contained_in(peaks, self.aux_windows)
            logger.info("Using secondary signal")
            logger.info("%d aux peaks detected", self.aux_windows.shape[0])
            logger.info("%d/%d peaks contained in second signal window",
                        aux_mask.sum(), peaks.shape[0])
            peaks = peaks[aux_mask]
            peak_amps = peak_amps[aux_mask]

        # 2c) use otsu's method to find a cutoff value
        peak_amp_thr = threshold_otsu(peak_amps) + self.pt_adjust
        otsu_mask = peak_amps > peak_amp_thr
        logger.info("otsu threshold: %.7f", peak_amp_thr)
        logger.info("%d/%d peaks survive", otsu_mask.sum(), peaks.shape[0])
        peaks = peaks[otsu_mask]
        peak_amps = peak_amps[otsu_mask]

        # 3) Make sure there is only one peak in each secondary window
        #    this is accomplished by estimating the distribution of peak
        #    times relative to aux window starts
        if self.use_secondary_heartbeat:
            npeaks = peaks.shape[0]
            # obtain a distribution of times and amplitudes
            rel_peak_times = np.zeros_like(peaks)
            window_subsets = []
            for start, end in self.aux_windows:
                mask = np.flatnonzero((peaks >= start) & (peaks <= end))
                if len(mask) == 0: continue
                window_subsets.append(mask)
                rel_peak_times[mask] = peaks[mask] - start

            # how common are relative peak times relative to window starts
            densities, bins = np.histogram(
                rel_peak_times,
                range=(0, self.secondary_heartbeat_pre_msec),
                bins=self.secondary_heartbeat_n_likelihood_bins,
                density=True)

            likelihoods = densities[np.clip(
                np.digitize(rel_peak_times, bins) - 1, 0,
                self.secondary_heartbeat_n_likelihood_bins - 1)]
            _peaks = []
            # Pull out the most likely peak in each window
            for subset in window_subsets:
                # If there's only a single peak contained, no need for math
                if len(subset) == 1:
                    _peaks.append(peaks[subset[0]])
                    continue

                _peaks.append(peaks[subset[np.argmax(likelihoods[subset])]])
            peaks = np.array(_peaks)
            logger.info("Only 1 peak per aux window allowed:" + \
                   " %d/%d peaks remaining",peaks.shape[0],npeaks)

        # Check that no two peaks are too close.
        # NOTE(review): 200 is a sample count; it equals the stated 300 BPM
        # cutoff only at a 1000 Hz sampling rate -- confirm.
        peak_diffs = np.ediff1d(peaks, to_begin=500)
        peaks = peaks[peak_diffs > 200]  # Cutoff is 300BPM

        # Stack the peaks and see if the original data has a higher value
        raw_stack = peak_stack(peaks,
                               self.ecg_signal,
                               pre_msec=self.peak_window,
                               post_msec=self.peak_window,
                               sampling_rate=self.ecg_ts.sampling_rate)
        # NOTE(review): argmax is a sample offset but peak_window is in msec;
        # consistent only at 1000 Hz -- confirm.
        adj_factors = np.argmax(raw_stack, axis=1) - self.peak_window
        peaks = peaks + adj_factors

        self.peak_indices = peaks
        self.peak_values = self.ecg_signal[peaks]
        self.peak_times = self.ecg_time[peaks]
        self.thr_vals = np.array([peak_amp_thr] * 2)
        t1 = time.time()
        # update the scatterplot if we're interactive
        if self.plot_data is not None:
            self.plot_data.set_data("peak_times", self.peak_times)
            self.plot_data.set_data("peak_values", self.peak_values)
            self.plot_data.set_data("qrs_power", self.qrs_power_signal)
            self.plot_data.set_data("aux_signal", self.aux_signal)
            self.plot_data.set_data("thr_vals", self.thr_vals)
            self.plot_data.set_data(
                "thr_times", np.array([self.ecg_time[0], self.ecg_time[-1]]))
            self.update_aux_window_graphics()
            self.plot.request_redraw()
            # Touch image_plot first: its default creates image_plot_data
            image_plot = self.image_plot
            self.image_plot_data.set_data("imagedata", self.ecg_matrix)
            image_plot.request_redraw()
        else:
            logger.info("plot data is none")
        logger.info("found %d QRS complexes in %.3f seconds",
                    len(self.peak_indices), t1 - t0)
        self.dirty = False

    def _plot_default(self):
        """
        Creates a plot of the ecg_ts data and the signals derived during
        the Pan Tomkins algorithm
        """
        # Create plotting components
        self.plot_data = ArrayPlotData(
            time=self.ecg_time,
            raw_ecg=self.ecg_signal,
            qrs_power=self.qrs_power_signal,
            aux_signal=self.aux_signal,
            # Ensembleable peaks
            peak_times=self.peak_times,
            peak_values=self.peak_values,
            # Non-ensembleable, useful for HR, HRV
            dne_peak_times=self.dne_peak_times,
            dne_peak_values=self.dne_peak_values,
            # Threshold slider bar
            thr_times=np.array([self.ecg_time[0], self.ecg_time[-1]]),
            thr_vals=np.array([0, 0]))

        plot = Plot(self.plot_data, use_backbuffer=True)
        self.aux_trace = plot.plot(("time", "aux_signal"),
                                   color="purple",
                                   alpha=0.6,
                                   line_width=2)[0]
        self.qrs_power_trace = plot.plot(("time", "qrs_power"),
                                         color="blue",
                                         alpha=0.8)[0]
        self.ecg_trace = plot.plot(("time", "raw_ecg"),
                                   color="red",
                                   line_width=2)[0]

        # Load the censor regions and add them to the plot
        for n, (start, end) in enumerate(self.censored_intervals):
            censor_key = "censor%03d" % n
            self.censored_region_graphics.append(
                StaticCensoredInterval(component=self.qrs_power_trace,
                                       metadata_name=censor_key))
            self.ecg_trace.overlays.append(self.censored_region_graphics[-1])
            self.ecg_trace.index.metadata[censor_key] = start, end

        # Horizontal line at the detection threshold (thr_vals over thr_times)
        self.threshold_trace = plot.plot(("thr_times", "thr_vals"),
                                         color="black",
                                         line_width=3)[0]

        # Plot for plausible peaks
        self.ppeak_vis = plot.plot(("dne_peak_times", "dne_peak_values"),
                                   type="scatter",
                                   marker="diamond")[0]

        # Make a scatter plot where you can edit the peaks
        self.peak_vis = plot.plot(("peak_times", "peak_values"),
                                  type="scatter")[0]
        self.scatter = plot.components[-1]
        self.peak_editor = PeakPickingTool(self.scatter)
        self.scatter.tools.append(self.peak_editor)
        # when the user adds or removes a point, automatically extract
        self.on_trait_event(self._change_peaks, "peak_editor.done_selecting")

        self.scatter.overlays.append(
            PeakPickingOverlay(component=self.scatter))

        return plot

    # TODO: have this update other variables in physiodata: hand_labeled, mea, etc
    def _delete_peaks(self, peaks_to_delete):
        pass

    def _add_peaks(self, peaks_to_add):
        pass

    # End TODO

    def _change_peaks(self):
        """Apply the latest PeakPickingTool selection to the peak lists.

        Modes:
          * "delete": remove regular and dne peaks strictly inside the interval.
          * "add": insert the raw-ECG maximum in the interval as a regular peak.
          * any other mode: insert that maximum as a dne peak.
        """
        if self.dirty:
            # NOTE(review): only warns; the edit still proceeds.
            messagebox("You shouldn't edit peaks until you've run Detect QRS")
        interval = self.peak_editor.selection
        if interval is None: return
        mode = self.peak_editor.selection_purpose
        logger.info("PeakPickingTool entered %s mode over %s", mode,
                    str(interval))
        if mode == "delete":
            # Do it for real peaks
            ok_peaks = np.logical_not(
                np.logical_and(self.peak_times > interval[0],
                               self.peak_times < interval[1]))
            self.peak_indices = self.peak_indices[ok_peaks]
            self.peak_values = self.peak_values[ok_peaks]
            self.peak_times = self.peak_times[ok_peaks]
            # And do it for the dne peaks, when there are any
            if self.dne_peak_indices is not None:
                ok_dne_peaks = np.logical_not(
                    np.logical_and(self.dne_peak_times > interval[0],
                                   self.dne_peak_times < interval[1]))
                self.dne_peak_indices = self.dne_peak_indices[ok_dne_peaks]
                self.dne_peak_values = self.dne_peak_values[ok_dne_peaks]
                self.dne_peak_times = self.dne_peak_times[ok_dne_peaks]
        else:
            # Find the signal contained in the selection
            sig = np.logical_and(self.ecg_time > interval[0],
                                 self.ecg_time < interval[1])
            sig_inds = np.flatnonzero(sig)
            if sig_inds.size == 0:
                logger.info("Selection %s contains no samples; no peak added",
                            str(interval))
                return
            selected_sig = self.ecg_signal[sig_inds]
            # find the peak in the selected region
            peak_ind = sig_inds[0] + np.argmax(selected_sig)

            # If we're in add peak mode, always make sure
            # that only unique peaks get added:
            if mode == "add":
                real_peaks = np.sort(
                    np.unique(self.peak_indices.tolist() + [peak_ind]))
                self.peak_indices = real_peaks
                self.peak_values = self.ecg_signal[real_peaks]
                self.peak_times = self.ecg_time[real_peaks]
            else:
                existing_dne = ([] if self.dne_peak_indices is None
                                else self.dne_peak_indices.tolist())
                real_dne_peaks = np.sort(np.unique(existing_dne + [peak_ind]))
                self.dne_peak_indices = real_dne_peaks
                self.dne_peak_values = self.ecg_signal[real_dne_peaks]
                self.dne_peak_times = self.ecg_time[real_dne_peaks]

        # update the scatterplot if we're interactive
        if self.plot_data is not None:
            self.plot_data.set_data("peak_times", self.peak_times)
            self.plot_data.set_data("peak_values", self.peak_values)
            self.plot_data.set_data("dne_peak_times", self.dne_peak_times)
            self.plot_data.set_data("dne_peak_values", self.dne_peak_values)
            # Touch image_plot first: its default creates image_plot_data
            image_plot = self.image_plot
            self.image_plot_data.set_data("imagedata", self.ecg_matrix)
            image_plot.request_redraw()

    @on_trait_change("window_size,start_time")
    def update_plot_range(self):
        self.plot.index_range.high = self.start_time + self.window_size
        self.plot.index_range.low = self.start_time

    @on_trait_change("image_plot_selected")
    def snap_to_image_selection(self):
        """Zoom the ECG plot to the beats selected in the beat-image plot.

        The selection bounds are used as row indices into ``peak_times``.
        They are clipped to the valid range, so a selection past the last
        peak, or a negative bound, cannot raise or wrap around.
        """
        if not self.peak_times.size > 0: return
        last_row = self.peak_times.size - 1
        row_min = min(max(int(self.image_selection_tool.ymin), 0), last_row)
        row_max = min(max(int(self.image_selection_tool.ymax), 0), last_row)
        tmin = self.peak_times[row_min] - 2.
        tmax = self.peak_times[row_max] + 2.
        logger.info("selection tool sends data to (%.2f, %.2f)", tmin, tmax)
        # NOTE(review): another 2 s of padding is applied here, so the shown
        # range extends 2 s beyond the logged bounds on each side -- confirm.
        self.plot.index_range.low = tmin - 2.
        self.plot.index_range.high = tmax + 2.

    def _image_plot_default(self):
        """Build the beat-image plot of ``ecg_matrix`` and its range selector.

        A 100x100 zero image is shown when ``ecg_matrix`` is empty.
        """
        # for image plot
        img = self.ecg_matrix
        if self.ecg_matrix.size == 0:
            img = np.zeros((100, 100))
        self.image_plot_data = ArrayPlotData(imagedata=img)
        plot = Plot(self.image_plot_data)
        self.image_selection_tool = ImageRangeSelector(component=plot,
                                                       tool_mode="range",
                                                       axis="value",
                                                       always_on=True)
        plot.img_plot("imagedata",
                      colormap=jet,
                      name="plot1",
                      origin="bottom left")[0]
        plot.overlays.append(self.image_selection_tool)
        return plot

    detection_params_group = VGroup(
        HGroup(
            Group(
                VGroup(Item("use_secondary_heartbeat"),
                       Item("peak_window"),
                       Item("apply_filter"),
                       Item("bandpass_min"),
                       Item("bandpass_max"),
                       Item("smoothing_window_len"),
                       Item("smoothing_window"),
                       Item("pt_adjust"),
                       Item("apply_diff_sq"),
                       label="Pan Tomkins"),
                VGroup(  # Items for MultiSignal detection
                    Item("secondary_heartbeat"),
                    Item("secondary_heartbeat_pre_msec"),
                    Item("secondary_heartbeat_abs"),
                    Item("secondary_heartbeat_window"),
                    Item("secondary_heartbeat_window_len"),
                    label="MultiSignal Detection",
                    enabled_when="use_secondary_heartbeat"),
                VGroup(  # Items for two ECG Signals
                    Item("use_ECG2"),
                    Item("ecg2_weight"),
                    label="Multiple ECG",
                    enabled_when="can_use_ecg2"),
                label="Beat Detection Options",
                layout="tabbed",
                show_border=True,
                springy=True),
            Item("image_plot", editor=ComponentEditor(), width=300),
            show_labels=False),
        Item("b_detect", enabled_when="dirty"),
        Item("b_shape"),
        show_labels=False)

    plot_group = Group(
        Group(Item("plot", editor=ComponentEditor(), width=800),
              show_labels=False), Item("start_time"), Item("window_size"))
    traits_view = MEAPView(
        VSplit(
            plot_group,
            detection_params_group,
            #orientation="vertical"
        ),
        resizable=True,
        buttons=[OKButton, CancelButton],
        title="Detect Heart Beats")
# Example #5
0
class RespirationProcessor(HasTraits):
    """Estimate respiration from a respiration belt and/or the ICG z0 channel.

    On construction the channels listed in ``physiodata.contents`` decide
    the processing ``state``:

    * ``"z0_resp"`` -- belt and z0/dzdt present; the belt signal is used as
      the respiration estimate and z0/dzdt are kept for correction.
    * ``"resp"`` -- only the respiration belt is present.
    * ``"z0"`` -- only z0/dzdt are present; z0 is used as the respiration
      estimate.
    * ``"unusable"`` -- neither channel is present; nothing is set up.

    :meth:`process` detrends, low-pass filters and z-scores the respiration
    estimate, locates its turning points with ``find_peaks``, regresses the
    result out of z0 and dZ/dt (impedance states only), and derives
    ``respiration_cycle`` / ``respiration_amount``.
    """
    # Parameters
    resp_polort = DelegatesTo("physiodata")
    resp_high_freq_cutoff = DelegatesTo("physiodata")

    time = DelegatesTo("physiodata", "processed_respiration_time")

    # Data object and arrays to save
    physiodata = Instance(PhysioData)
    respiration_cycle = DelegatesTo("physiodata")
    respiration_amount = DelegatesTo("physiodata")

    # For visualization
    plots = Instance(VPlotContainer, transient=True)

    # Respiration plot
    resp_plot = Instance(Plot, transient=True)
    resp_plot_data = Instance(ArrayPlotData, transient=True)
    raw_resp_signal = Array()  # Could be filtered z0 or resp belt
    # Winsorized, z-scored version of raw_resp_signal (see _get_resp_signal)
    resp_signal = Property(
        Array,
        depends_on="physiodata.resp_polort,physiodata.resp_high_freq_cutoff")
    polyfilt_resp = Array()
    lpfilt_resp = DelegatesTo("physiodata", "processed_respiration_data")
    resp_inhale_begin_times = DelegatesTo("physiodata")
    resp_inhale_begin_values = Array
    resp_exhale_begin_times = DelegatesTo("physiodata")
    resp_exhale_begin_values = Array

    z0_plot = Instance(Plot, transient=True)
    z0_plot_data = Instance(ArrayPlotData, transient=True)
    raw_z0_signal = Array()
    resp_corrected_z0 = DelegatesTo("physiodata")

    dzdt_plot = Instance(Plot, transient=True)
    dzdt_plot_data = Instance(ArrayPlotData, transient=True)
    raw_dzdt_signal = Array()
    resp_corrected_dzdt = DelegatesTo("physiodata")

    start_time = Range(low=0, high=10000.0, initial=0.0)
    window_size = Range(low=1.0, high=300.0, initial=30.0)

    state = Enum("unusable", "z0", "resp", "z0_resp")
    dirty = Bool(True)

    b_process = Button(label="Process Respiration")

    def __init__(self, **traits):
        super(RespirationProcessor, self).__init__(**traits)
        resp_inc = "respiration" in self.physiodata.contents
        z0_inc = "z0" in self.physiodata.contents
        if z0_inc and resp_inc:
            self.state = "z0_resp"
            messagebox("Using both respiration belt and ICG")
            z0_signal = self.physiodata.z0_data
            dzdt_signal = self.physiodata.dzdt_data
            resp_signal = self.physiodata.respiration_data
        elif resp_inc:
            self.state = "resp"
            messagebox("Only respiration belt data will be used.")
            # Same attribute as the "z0_resp" branch (was ``resp_data``).
            resp_signal = self.physiodata.respiration_data
            z0_signal = None
            dzdt_signal = None
        elif z0_inc:
            self.state = "z0"
            messagebox("Using only z0 channel to estimate respiration")
            resp_signal = self.physiodata.z0_data
            z0_signal = self.physiodata.z0_data
            dzdt_signal = self.physiodata.dzdt_data
            self.resp_polort = 1
        else:
            self.state = "unusable"
            messagebox("No respiration belt or z0 channels found")
            # Nothing to align or plot. Falling through used to raise a
            # NameError on the unbound signal variables below.
            return

        # Truncate every available signal to their shortest shared length
        signals = [sig for sig in (resp_signal, z0_signal, dzdt_signal)
                   if sig is not None]
        minlen = min(len(sig) for sig in signals)

        # Establish a time array
        if resp_inc:
            self.time = TimeSeries(physiodata=self.physiodata,
                                   contains="respiration").time[:minlen]
        else:
            self.time = TimeSeries(physiodata=self.physiodata,
                                   contains="z0").time[:minlen]

        # Get out the final signals. The first 50 impedance samples are
        # replaced by the signal mean.
        if z0_inc:
            self.raw_resp_signal = z0_signal[:minlen].copy()
            self.raw_resp_signal[:50] = self.raw_resp_signal.mean()
            self.raw_z0_signal = z0_signal[:minlen].copy()
            self.raw_z0_signal[:50] = self.raw_z0_signal.mean()
            self.raw_dzdt_signal = dzdt_signal[:minlen].copy()
            self.raw_dzdt_signal[:50] = self.raw_dzdt_signal.mean()
        if resp_inc:
            # The belt takes precedence over z0 as the respiration estimate
            self.raw_resp_signal = resp_signal[:minlen].copy()
        # if data already exists, it can't be dirty
        if self.physiodata.resp_exhale_begin_times.size > 0:
            self.dirty = False
        self.on_trait_change(self._parameter_changed,
                             "resp_polort,resp_high_freq_cutoff")

    @cached_property
    def _get_resp_signal(self):
        """Return ``raw_resp_signal`` winsorized and z-scored."""
        resp_signal = winsorize(self.raw_resp_signal)
        return (resp_signal - resp_signal.mean()) / resp_signal.std()

    def _parameter_changed(self):
        """Mark results stale when a processing parameter changes."""
        self.dirty = True

    def _b_process_fired(self):
        self.process()

    def process(self):
        """
        Process the respiration timeseries.

        Shows a message box and returns early if detrending or low-pass
        filtering does not produce an array matching ``time``.
        """
        sampling_rate = 1. / (self.time[1] - self.time[0])
        # Respiration belts can lose tension over time.
        # This removes polynomial (Legendre) trends up to resp_polort.
        if self.resp_polort > 0:
            pfit = legendre_detrend(self.resp_signal, self.resp_polort)
            if not pfit.shape == self.time.shape:
                messagebox("Legendre detrend failed")
                return
            self.polyfilt_resp = pfit
        else:
            self.polyfilt_resp = self.resp_signal

        lpd = lowpass(self.polyfilt_resp, self.resp_high_freq_cutoff,
                      sampling_rate)
        if not lpd.shape == self.time.shape:
            messagebox("lowpass filter failed")
            return

        self.lpfilt_resp = (lpd - lpd.mean()) / lpd.std()

        resp_inhale_begin_indices, resp_exhale_begin_indices = find_peaks(
            self.lpfilt_resp, maxima=True, minima=True)
        self.resp_inhale_begin_times = self.time[resp_inhale_begin_indices]
        self.resp_exhale_begin_times = self.time[resp_exhale_begin_indices]

        # z0/dZ/dt only exist in the impedance states; in the "resp" state
        # raw_z0_signal and raw_dzdt_signal are still empty arrays.
        if self.state in ("z0", "z0_resp"):
            self.resp_corrected_z0 = regress_out(self.raw_z0_signal,
                                                 self.lpfilt_resp)
            self.resp_corrected_dzdt = regress_out(self.raw_dzdt_signal,
                                                   self.lpfilt_resp)

        # update the scatterplot if we're interactive
        if self.resp_plot_data is not None:
            self.resp_plot_data.set_data("inhale_times",
                                         self.resp_inhale_begin_times)
            self.resp_plot_data.set_data(
                "inhale_values", self.lpfilt_resp[resp_inhale_begin_indices])
            self.resp_plot_data.set_data("exhale_times",
                                         self.resp_exhale_begin_times)
            self.resp_plot_data.set_data(
                "exhale_values", self.lpfilt_resp[resp_exhale_begin_indices])
            self.resp_plot_data.set_data("lpfilt_resp", self.lpfilt_resp)
            # The z0/dZ/dt plots are only built for the impedance states
            if self.dzdt_plot_data is not None:
                self.dzdt_plot_data.set_data("cleaned_data",
                                             self.resp_corrected_dzdt)
            if self.z0_plot_data is not None:
                self.z0_plot_data.set_data("cleaned_data",
                                           self.resp_corrected_z0)
        else:
            print("plot data is none")

        # Update the respiration cycle: 0 at exhale onsets and at both ends
        # of the recording, 0.5 at inhale onsets, linearly interpolated.
        times = np.concatenate([
            np.zeros(1), self.resp_exhale_begin_times,
            self.resp_inhale_begin_times, self.time[-1, np.newaxis]
        ])
        vals = np.concatenate([
            np.zeros(1),
            np.zeros_like(resp_exhale_begin_indices),
            0.5 * np.ones_like(resp_inhale_begin_indices),
            np.zeros(1)
        ])
        srt = np.argsort(times)
        terp = interp1d(times[srt], vals[srt])
        self.respiration_cycle = terp(self.time)
        # Absolute sample-to-sample change of the cycle
        self.respiration_amount = np.abs(
            np.ediff1d(self.respiration_cycle, to_begin=0))

    def _resp_plot_default(self):
        """Build the respiration plot (raw, low-passed and turning points)."""
        # Create plotting components
        self.resp_plot_data = ArrayPlotData(
            time=self.time,
            inhale_times=self.resp_inhale_begin_times,
            inhale_values=self.resp_inhale_begin_values,
            exhale_times=self.resp_exhale_begin_times,
            exhale_values=self.resp_exhale_begin_values,
            filtered_resp=self.resp_signal,
            lpfilt_resp=self.lpfilt_resp)
        plot = Plot(self.resp_plot_data)
        plot.plot(("time", "filtered_resp"), color="blue", line_width=1)
        plot.plot(("time", "lpfilt_resp"), color="green", line_width=1)
        # Plot the inhalation peaks
        plot.plot(("inhale_times", "inhale_values"),
                  type="scatter",
                  marker="square")
        plot.plot(("exhale_times", "exhale_values"),
                  type="scatter",
                  marker="circle")
        plot.title = "Respiration"
        plot.title_position = "right"
        plot.title_angle = 270
        plot.padding = 20
        return plot

    def _z0_plot_default(self):
        """Build the plot of raw vs. respiration-corrected z0."""
        self.z0_plot_data = ArrayPlotData(time=self.time,
                                          raw_data=self.raw_z0_signal,
                                          cleaned_data=self.resp_corrected_z0)
        plot = Plot(self.z0_plot_data)
        plot.plot(("time", "raw_data"), color="blue", line_width=1)
        plot.plot(("time", "cleaned_data"), color="green", line_width=1)
        plot.title = "z0"
        plot.title_position = "right"
        plot.title_angle = 270
        plot.padding = 20
        return plot

    def _dzdt_plot_default(self):
        """Build the plot of raw vs. respiration-corrected dZ/dt."""
        # Create plotting components
        self.dzdt_plot_data = ArrayPlotData(
            time=self.time,
            raw_data=self.raw_dzdt_signal,
            cleaned_data=self.resp_corrected_dzdt)
        plot = Plot(self.dzdt_plot_data)
        plot.plot(("time", "raw_data"), color="blue", line_width=1)
        plot.plot(("time", "cleaned_data"), color="green", line_width=1)
        plot.title = "dZ/dt"
        plot.title_position = "right"
        plot.title_angle = 270
        plot.padding = 20
        return plot

    def _plots_default(self):
        """Stack the plots relevant to ``state``, sharing one index range."""
        plots_to_include = []
        if self.state in ("z0_resp", "z0"):
            self.index_range = self.resp_plot.index_range
            self.dzdt_plot.index_range = self.index_range
            self.z0_plot.index_range = self.index_range
            plots_to_include = [self.resp_plot, self.z0_plot, self.dzdt_plot]
        elif self.state == "resp":
            self.index_range = self.resp_plot.index_range
            plots_to_include = [self.resp_plot]
        return VPlotContainer(*plots_to_include)

    @on_trait_change("window_size,start_time")
    def update_plot_range(self):
        """Show [start_time, start_time + window_size] on the shared x axis."""
        self.resp_plot.index_range.high = self.start_time + self.window_size
        self.resp_plot.index_range.low = self.start_time

    proc_params_group = Group(
        Group(
            Item("resp_polort"),
            Item("resp_high_freq_cutoff"),
            # 'unusable' must be quoted: it is an Enum value, not a trait
            Item("b_process",
                 show_label=False,
                 enabled_when="state != 'unusable' and dirty"),
            label="Resp processing options",
            show_border=True,
            orientation="vertical",
            springy=True),
        orientation="horizontal")

    plot_group = Group(
        Group(Item("plots", editor=ComponentEditor(), width=800, height=700),
              show_labels=False), Item("start_time"), Item("window_size"))
    traits_view = MEAPView(
        VSplit(
            plot_group,
            proc_params_group,
        ),
        resizable=True,
        win_title="Process Respiration Data",
    )
Beispiel #6
0
class MEAPConfig(HasTraits):
    """Processing configuration: smoothing, waveform extraction, peak
    detection, respiration, moving-ensemble, B-point and SRVF parameters.

    NOTE(review): ``extraction_group`` lists ``mhd_*``, ``qrs_to_mhd_ratio``
    and ``combined_smoothing_*`` items that are not declared as traits in
    this class -- confirm they are defined elsewhere before showing the view.
    """

    # Parameters for point marking
    apply_ecg_smoothing = CBool(True)
    ecg_smoothing_window_len = Int(5) # NOTE: Changed in 1.1 from 20
    apply_imp_smoothing = CBool(True)
    imp_smoothing_window_len = Int(40)
    apply_bp_smoothing = CBool(True)
    bp_smoothing_window_len = Int(80)

    # Parameters for waveform extraction
    # (pre/post windows around each beat; presumably msec -- TODO confirm)
    peak_window = CInt(80) #Range(low=5,high=400, value= 200)
    ecg_pre_peak = CInt(300) #Range(low=50, high=500, value=300)
    ecg_post_peak = CInt(400) #Range(low=100, high=700, value=400)
    dzdt_pre_peak = CInt(300) #Range(50, 500, value=300)
    dzdt_post_peak = CInt(700) #Range(100, 1000, value=700)
    doppler_pre_peak = CInt(300) #Range(50, 500, value=300)
    doppler_post_peak = CInt(700) #Range(100, 1000, value=700)
    bp_pre_peak = CInt(300) #Range(50, 500, value=300)
    bp_post_peak = CInt(1000) # formerly Range(100, 2500); default now 1000
    systolic_pre_peak = CInt(300) #Range(50, 500, value=300)
    systolic_post_peak = CInt(1000) # formerly Range(100, 2500); default now 1000
    diastolic_pre_peak = CInt(300) #Range(50, 500, value=300)
    diastolic_post_peak = CInt(1000) # formerly Range(100, 2500); default now 1000
    stroke_volume_equation = Enum("Kubicek","Sramek-Bernstein")
    # View fragment for the extraction parameters; the second group is only
    # enabled for MRI sessions.
    extraction_group = Group(
        Group(
            Item("peak_window"),
            Item("ecg_pre_peak"),
            Item("ecg_post_peak"),
            Item("dzdt_pre_peak"),
            Item("dzdt_post_peak"),
            Item("bp_pre_peak"),
            Item("bp_post_peak"),
            Item("stroke_volume_equation"),
            orientation="vertical",
            show_border=True,
            label = "Waveform Extraction Parameters",
            springy=True
            ),
        Group(
            Item("mhd_bandpass_min"),
            Item("mhd_bandpass_max"),
            Item("mhd_smoothing_window_len"),
            Item("mhd_smoothing_window"),
            Item("qrs_to_mhd_ratio"),
            Item("combined_smoothing_window_len"),
            Item("combined_smoothing_window"),
            enabled_when="subject_in_mri"
        )
    )

    # Respiration analysis parameters
    process_respiration = CBool(True)
    resp_polort = CInt(7)  # Legendre detrend order (0 disables detrending)
    resp_high_freq_cutoff = CFloat(0.35)  # low-pass cutoff; presumably Hz
    regress_out_resp = Bool(False)

    # parameters for processing the raw data before PT detecting
    # MRI-specific
    subject_in_mri = CBool(False)
    peak_detection_algorithm = Enum("Pan Tomkins 83", "Multisignal", "ECG2")

    # PanTomkins algorithm
    bandpass_min = Range(low=1,high=200,initial=5,value=5)
    bandpass_max = Range(low=1,high=200,initial=15,value=15)
    smoothing_window_len = Range(low=10,high=1000,initial=100,value=100)
    smoothing_window = Enum(SMOOTHING_WINDOWS)
    pt_adjust = Range(low=-2.,high=2.,value=0.00)
    peak_threshold=CFloat
    apply_filter = CBool(True)
    apply_diff_sq = CBool(True)
    apply_smooth_ma = CBool(True)
    # View fragment for the Pan Tomkins R-peak options
    pt_params_group = Group(
        Item("apply_filter"),
        Item("bandpass_min"),#editor=RangeEditor(enter_set=True)),
        Item("bandpass_max"),#editor=RangeEditor(enter_set=True)),
        Item("smoothing_window_len"),#editor=RangeEditor(enter_set=True)),
        Item("smoothing_window"),
        Item("pt_adjust"),#editor=RangeEditor(enter_set=True)),
        Item("apply_diff_sq"),
        Item("subject_in_mri"),
        label="R peak detection options",
        show_border=True,
        orientation="vertical",
        springy=True
    )

    # Second signal heartbeat detection
    use_secondary_heartbeat = CBool(False)
    secondary_heartbeat = Enum("dzdt", "pulse_ox", "bp")
    secondary_heartbeat_pre_msec = CInt(400)
    secondary_heartbeat_abs = CBool(True)
    secondary_heartbeat_window = Enum(SMOOTHING_WINDOWS)
    secondary_heartbeat_window_len = CInt(801)
    secondary_heartbeat_n_likelihood_bins = CInt(15)

    # ECG2 parameters
    use_ECG2 = CBool(False)
    qrs_signal_source = Enum("ecg", "ecg2")
    ecg2_weight = Range(low=0., high=1.,value=0.5)

    # Moving Ensembling Parameters
    mea_window_type = Enum("Seconds","Beats")
    mea_n_neighbors = Range(low=0, high=60, value=8)
    mea_window_secs = Range(low=1., high=60, value=15.)
    mea_exp_power = Enum(2,3,4,5,6)
    mea_func_name = Enum("linear","exponential","flat")
    mea_weight_direction = Enum("symmetric", "before", "after")
    mea_smooth_hr = CBool(True)
    use_trimmed_co = CBool(True)

    #bpoint classifier parameters
    bpoint_classifier_pre_point_msec = CInt(20)
    bpoint_classifier_post_point_msec = CInt(20)
    bpoint_classifier_sample_every_n_msec = CInt(1)
    bpoint_classifier_false_distance_min = CInt(5)
    bpoint_classifier_use_bpoint_prior = CBool(True)
    bpoint_classifier_include_derivative = CBool(True)

    # Doppler D point config
    db_point_type = Enum("min", "max")
    db_point_window_len = CInt(20)
    dx_point_type = Enum("min", "max")
    dx_point_window_len = CInt(20)

    # SRVF-warping parameters
    srvf_lambda = CFloat(0.0)
    srvf_max_karcher_iterations = CInt(15)
    srvf_update_min = CFloat(0.01)
    srvf_karcher_mean_subset_size = CInt(30)
    srvf_multi_mode_variance_cutoff = CFloat(0.3)
    srvf_use_moving_ensembled = CBool(False)
    dzdt_num_inputs_to_group_warping = CInt(25)
    srvf_t_min = CInt(0) # Time in msec relative to R
    srvf_t_max = CInt(150) # Also time in msec relative to R
    bspline_before_warping = CBool(True)
    n_modes = CInt(5)
    max_kmeans_iterations = CInt(5)
Beispiel #7
0
class MEAPTimeseries(HasTraits):
    """Interactive beat-wise plot of one signal with range selection.

    ``signal`` names an array in ``plotdata`` that is drawn against
    ``peak_times`` as a line plus a scatter coloured by ``beat_type``.
    A :class:`BeatSelection` tool is attached to the scatter; when a
    selection is completed, ``selected_range`` is set to it and the
    ``selected`` event fires.
    """
    physiodata = Instance(PhysioData)
    plot = Instance(Plot, transient=True)
    plotdata = Instance(ArrayPlotData, transient=True)
    selection = Instance(BeatSelection, transient=True)
    signal = Str
    selected_beats = Array
    index_datasourse = Instance(DataRange1D)
    marker_size = Array
    metadata_name = Str
    selected = Event
    selected_range = Tuple

    # NOTE(review): ``size`` and ``title`` are not defined in this class;
    # they must exist as module-level names where this class is defined.
    traits_view = MEAPView(Group(Item('plot',
                                      editor=ComponentEditor(size=size),
                                      show_label=False),
                                 orientation="vertical"),
                           resizable=True,
                           win_title=title)

    def __init__(self, **traits):
        super(MEAPTimeseries, self).__init__(**traits)
        self.metadata_name = self.signal + "_selection"

    def _selection_changed(self):
        """Publish a completed, non-empty selection via ``selected``."""
        selection = self.selection.selection
        logger.info("%s selection changed to %s", self.signal, str(selection))
        if selection is None or len(selection) == 0:
            return
        self.selected_range = selection
        self.selected = True

    def _plot_default(self):
        """Build the plot and attach the beat-selection tool."""
        plot = Plot(self.plotdata)

        # Name of the array drawn as the main (blue) trace
        signal = self.signal
        if self.signal in ("tpr", "co", "sv"):
            rc_signal = "resp_corrected_" + self.signal
            if rc_signal in self.plotdata.arrays and \
                    self.plotdata.arrays[rc_signal].size > 0:
                # Plot the resp-uncorrected version underneath and show the
                # respiration-corrected one as the main trace. (Previously
                # ``signal`` was unconditionally reset to self.signal here.)
                plot.plot(("peak_times", self.signal),
                          type="line",
                          color="purple")
                signal = rc_signal
        elif self.signal == "hr":
            # Raw HR underneath, moving-ensembled HR as the main trace
            plot.plot(("peak_times", self.signal), type="line", color="purple")
            signal = "mea_hr"

        # Create the plot
        plot.plot(("peak_times", signal), type="line", color="blue")
        plot.plot(("peak_times", signal, "beat_type"),
                  type="cmap_scatter",
                  color_mapper=jet,
                  name="my_plot",
                  marker="circle",
                  border_visible=False,
                  outline_color="transparent",
                  bg_color="white",
                  index_sort="ascending",
                  marker_size=3,
                  fill_alpha=0.8)

        # Tweak some of the plot properties
        plot.title = self.signal
        plot.title_position = "right"
        plot.title_angle = 270
        plot.line_width = 1
        plot.padding = 20

        # The selection tools need the actual ScatterPlot renderer
        my_plot = plot.plots["my_plot"][0]

        # Attach some tools to the plot
        self.selection = BeatSelection(component=my_plot)
        my_plot.active_tool = self.selection
        selection_overlay = RangeSelectionOverlay(component=my_plot)
        my_plot.overlays.append(selection_overlay)

        # Set up the trait handler for the selection
        self.selection.on_trait_change(self._selection_changed,
                                       'selection_completed')

        return plot
Beispiel #8
0
class Importer(HasTraits):
    """Base class for importing a physiological recording.

    Holds the recording's channels, an optional channel-name mapping read
    from ``mapping_txt`` and the participant measurements entered in the
    import dialog. Subclasses must implement :meth:`get_physiodata`.
    """
    path = File
    channels = List(Instance(Channel))
    mapping_txt = File()
    # channel name -> signal name, filled from mapping_txt
    mapper = Dict()

    # --- Subject information
    subject_age = CFloat(20.)
    subject_gender = Enum("M", "F")
    subject_weight = CFloat(135., label="Weight (lbs)")
    subject_height_ft = CInt(5, label="Height (ft)",
                             desc="Subject's height in feet")
    subject_height_in = CInt(10, label="Height (in)",
                             desc="Subject's height in inches")
    subject_electrode_distance_front = CFloat(
        32, label="Impedance electrode distance (front)")
    subject_electrode_distance_back = CFloat(
        32, label="Impedance electrode distance (back)")
    subject_electrode_distance_right = CFloat(
        32, label="Impedance electrode distance (right)")
    subject_electrode_distance_left = CFloat(
        32, label="Impedance electrode distance (left)")
    subject_in_mri = CBool(False)
    subject_control_base_impedance = CFloat(
        0., label="Control Impedance",
        desc="If in MRI, store the z0 value from outside the MRI")
    subject_resp_max = CFloat(100., label="Respiration circumference max (cm)")
    subject_resp_min = CFloat(0., label="Respiration circumference min (cm)")

    # NOTE(review): no ``channel_map`` trait is declared in this class, so
    # this default handler appears unused -- confirm before removing.
    def _channel_map_default(self):
        return ChannelMapper()

    def _mapping_txt_changed(self):
        """Rebuild ``mapper`` from ``mapping_txt``.

        Each line must look like ``channel name -> signal``; only signals in
        SUPPORTED_SIGNALS are kept. Malformed or unsupported lines are
        logged and skipped. A nonexistent path is ignored.
        """
        if not os.path.exists(self.mapping_txt):
            return
        with open(self.mapping_txt) as mapping_file:
            lines = [line.strip() for line in mapping_file]
        self.mapper = {}
        for line in lines:
            parts = line.split("->")
            if not len(parts) == 2:
                logger.info("Unable to parse line: %s", line)
                continue
            map_from, map_to = [part.strip() for part in parts]
            if map_to in SUPPORTED_SIGNALS:
                self.mapper[map_from] = map_to
                logger.info("mapping %s to %s", map_from, map_to)
            else:
                logger.info("Unrecognized channel: %s", map_to)

    traits_view = MEAPView(
        Group(
            Item("channels", editor=channel_table),
            label="Specify channel contents",
            show_border=True,
            show_labels=False
        ),
        Group(
            Item("subject_age"), Item("subject_gender"),
            Item("subject_height_ft"), Item("subject_height_in"),
            Item("subject_weight"),
            Item("subject_electrode_distance_front"),
            Item("subject_electrode_distance_back"),
            Item("subject_electrode_distance_left"),
            Item("subject_electrode_distance_right"),
            Item("subject_resp_max"), Item("subject_resp_min"),
            Item("subject_in_mri"), Item('subject_control_base_impedance'),
            label="Participant Measurements"
        ),
        buttons=[OKButton, CancelButton],
        resizable=True,
        win_title="Import Data",
    )

    def guess_channel_contents(self):
        """Set ``chan.contains`` for every channel.

        Explicit entries in ``mapper`` win. Otherwise the lower-cased
        channel name is matched against keywords; a guess is only made for
        a signal that ``mapper`` does not already provide.
        """
        mapped = set(self.mapper.values())
        for chan in self.channels:
            if chan.name in self.mapper:
                chan.contains = self.mapper[chan.name]
                continue
            cname = chan.name.lower()
            if "magnitude" in cname and "z0" not in mapped:
                chan.contains = "z0"
            elif "ecg" in cname and "ecg" not in mapped:
                chan.contains = "ecg"
            # Parentheses matter: the "already mapped" check must apply to
            # both keywords (``and`` binds tighter than ``or``).
            elif ("dz/dt" in cname or "derivative" in cname) and \
                    "dzdt" not in mapped:
                chan.contains = "dzdt"
            elif "diastolic" in cname and "diastolic" not in mapped:
                chan.contains = "diastolic"
            elif "systolic" in cname and "systolic" not in mapped:
                chan.contains = "systolic"
            elif ("blood pressure" in cname or "bp" in cname) and \
                    "bp" not in mapped:
                chan.contains = "bp"
            elif "resp" in cname and "respiration" not in mapped:
                chan.contains = "respiration"
            elif "trigger" in cname and "mri_trigger" not in mapped:
                chan.contains = "mri_trigger"
            elif "stimulus" in cname and "event" not in mapped:
                chan.contains = "event"

    def subject_data(self):
        """Return the participant measurements as a dict of trait values.

        NOTE(review): ``subject_control_base_impedance`` is shown in the
        view but not included here -- confirm whether that is intended.
        """
        return dict(
            subject_age=self.subject_age,
            subject_gender=self.subject_gender,
            subject_weight=self.subject_weight,
            subject_height_ft=self.subject_height_ft,
            subject_height_in=self.subject_height_in,
            subject_electrode_distance_front=self.subject_electrode_distance_front,
            subject_electrode_distance_back=self.subject_electrode_distance_back,
            subject_electrode_distance_left=self.subject_electrode_distance_left,
            subject_electrode_distance_right=self.subject_electrode_distance_right,
            subject_resp_max=self.subject_resp_max,
            subject_resp_min=self.subject_resp_min,
            subject_in_mri=self.subject_in_mri
        )

    def get_physiodata(self):
        """Build the PhysioData object; must be implemented by subclasses."""
        raise NotImplementedError("Must be overwritten by subclass")