Exemplo n.º 1
0
Arquivo: base.py Projeto: tfiers/sharp
class InputDataMixin:
    """
    Mixin exposing the downsampled recordings and the reference segments,
    each in full and split into a train and a test part.
    """

    downsampler = DownsampleAllRecordings()
    reference_maker = MakeReference()

    # Tasks that use this mixin should include these in what their
    # `requires()` returns.
    input_data_makers = (downsampler, reference_maker)

    # --- All channels -----------------------------------------------------

    @property
    @cached
    def multichannel_full(self):
        """Multichannel signal, as returned by the downsampler."""
        return self.downsampler.get_multichannel()

    @property
    def multichannel_train(self):
        split = TrainTestSplit(self.multichannel_full)
        return split.signal_train

    @property
    def multichannel_test(self):
        split = TrainTestSplit(self.multichannel_full)
        return split.signal_test

    # --- Reference channel only -------------------------------------------

    @property
    @cached
    def reference_channel_full(self):
        """Reference channel, as returned by the downsampler."""
        return self.downsampler.get_reference_channel()

    @property
    def reference_channel_train(self):
        split = TrainTestSplit(self.reference_channel_full)
        return split.signal_train

    @property
    def reference_channel_test(self):
        split = TrainTestSplit(self.reference_channel_full)
        return split.signal_test

    # --- Reference segments -----------------------------------------------

    @property
    @cached
    def reference_segs_all(self):
        """All reference segments, read from the reference maker's output."""
        target = self.reference_maker.output()
        return target.read()

    @property
    def reference_segs_train(self):
        split = self._split_refsegs
        return split.segments_train

    @property
    def reference_segs_test(self):
        split = self._split_refsegs
        return split.segments_test

    @property
    def _split_refsegs(self):
        # The segments are split relative to the full reference channel.
        signal = self.reference_channel_full
        segments = self.reference_segs_all
        return TrainTestSplit(signal, segments)
Exemplo n.º 2
0
Arquivo: grid.py Projeto: tfiers/sharp
 def ripple_thresholds(self):
     """
     Ripple detection threshold for every ripple multiplier in the config.

     The SW multiplier is held at the first value of its grid.
     """
     thresholds = []
     for mult_ripple in config.mult_detect_ripple:
         ref_maker = MakeReference(
             mult_detect_SW=config.mult_detect_SW[0],
             mult_detect_ripple=mult_ripple,
         )
         thresholds.append(ref_maker.threshold_detect_ripple)
     return thresholds
Exemplo n.º 3
0
Arquivo: grid.py Projeto: tfiers/sharp
 def SW_thresholds(self):
     """
     Sharp-wave detection threshold for every SW multiplier in the config.

     The ripple multiplier is held at the first value of its grid.
     """
     thresholds = []
     for mult_SW in config.mult_detect_SW:
         ref_maker = MakeReference(
             mult_detect_SW=mult_SW,
             mult_detect_ripple=config.mult_detect_ripple[0],
         )
         thresholds.append(ref_maker.threshold_detect_SW)
     return thresholds
Exemplo n.º 4
0
class ThresholdSweeper(SharpTask):
    """
    Calculate performance of a detector based on its output envelope, for
    different detection thresholds.

    The test part of the envelope is evaluated against the reference
    segments (made by `reference_maker`) of the test split, for
    `config.num_thresholds` thresholds proposed by `ThresholdSweep`.
    """

    output_root = intermediate_output_dir / "threshold-sweeps"

    envelope_maker: EnvelopeMaker = TaskParameter()
    reference_maker: MakeReference = TaskParameter(default=MakeReference())

    def requires(self):
        return (self.envelope_maker, self.reference_maker)

    def output(self) -> ThresholdSweepFile:
        # One sweep file per (reference definition, envelope maker) pair.
        return ThresholdSweepFile(
            directory=(self.output_root /
                       self.reference_maker.output_filename /
                       self.envelope_maker.output_subdir),
            filename=self.envelope_maker.output_filename,
        )

    @property
    def lockout_time(self) -> float:
        """Lockout time in seconds (logged as milliseconds in `work`)."""
        # A disabled, data-driven alternative was: the
        # `config.lockout_percentile`-th percentile of the reference segment
        # durations.
        return config.lockout_time

    def work(self):
        sweep = ThresholdSweep()
        # Hoisted out of the loop below: these are loop-invariant, and
        # `envelope_test` may be recomputed on every property access.
        envelope_test = self.envelope_maker.envelope_test
        lockout_time = self.lockout_time
        threshold_range = envelope_test.range
        log.info(f"Evaluating {config.num_thresholds} thresholds, "
                 f"with a lockout time of {1000 * lockout_time:.3g} ms.")
        refsegs_all = self.reference_maker.output().read()
        # Reference segments belonging to the test split of the envelope.
        refsegs_test = TrainTestSplit(self.envelope_maker.envelope,
                                      refsegs_all).segments_test
        # `sweep` chooses each next threshold itself; keep going until it
        # holds the configured number of evaluations.
        while len(sweep.thresholds) < config.num_thresholds:
            threshold = sweep.get_next_threshold(threshold_range)
            new_threshold_evaluation = evaluate_threshold(
                envelope_test,
                threshold,
                lockout_time,
                refsegs_test,
            )
            sweep.add_threshold_evaluation(new_threshold_evaluation)

        self.output().write(sweep)
Exemplo n.º 5
0
Arquivo: grid.py Projeto: tfiers/sharp
 def data_matrix(self):
     """
     Nested list of `get_data` values, one per reference definition.

     Rows follow `config.mult_detect_SW`; columns follow
     `config.mult_detect_ripple`.
     """
     matrix = []
     for SW in config.mult_detect_SW:
         row = []
         for ripple in config.mult_detect_ripple:
             sweeper = ThresholdSweeper(
                 envelope_maker=self.envelope_maker,
                 reference_maker=MakeReference(
                     mult_detect_ripple=ripple, mult_detect_SW=SW
                 ),
             )
             sweep = sweeper.output().read()
             row.append(self.get_data(sweep))
         matrix.append(row)
     return matrix
Exemplo n.º 6
0
from sharp.config.load import config
from sharp.data.hardcoded.style import blue, pink
from sharp.data.types.evaluation.sweep import ThresholdSweep
from sharp.data.types.evaluation.threshold import ThresholdEvaluation
from sharp.tasks.evaluate.sweep import ThresholdSweeper
from sharp.tasks.neuralnet.apply import ApplyRNN
from sharp.tasks.plot.base import FigureMaker
from sharp.tasks.signal.online_bpf import ApplyOnlineBPF
from sharp.tasks.signal.reference import MakeReference

# Figures for the minipaper go into their own subdirectory.
output_dir = FigureMaker.output_dir / "minipaper"

# Reference definition shared by both sweeps: the entry at index 3 of both
# threshold-multiplier grids. (Index 4 of `config.mult_detect_ripple` was
# also tried as the ripple definition.)
rm = MakeReference(
    mult_detect_SW=config.mult_detect_SW[3],
    mult_detect_ripple=config.mult_detect_ripple[3],
)

# One threshold sweep per detector, scored against the same reference.
sweeper_rnn = ThresholdSweeper(
    reference_maker=rm,
    envelope_maker=ApplyRNN(),
)
sweeper_bpf = ThresholdSweeper(
    reference_maker=rm,
    envelope_maker=ApplyOnlineBPF(),
)
color_rnn, color_bpf = pink, blue

# Parallel tuples: position i of each describes the same detector.
sweepers = (sweeper_bpf, sweeper_rnn)
colors = (color_bpf, color_rnn)
labels = ("Band-pass filter", "Recurrent neural net")


def get_sweeps() -> Tuple[ThresholdSweep, ...]:
Exemplo n.º 7
0
class PlotReferenceMaker(TimeRangesPlotter):
    """
    Plot the input signal together with the reference maker's envelopes, and
    overlay the reference segments. Only the test slice of the data is shown.
    """

    selected_time_ranges_only = False
    reference_channel_only = False
    full_range_scalebars = True
    output_dir = TimeRangesPlotter.output_dir / "reference"
    window_size = 2
    figwidth = 1

    # One reference maker per argument set in `MakeReference.args`. The first
    # one provides the envelopes that are plotted.
    reference_makers = [
        MakeReference(**kwargs) for kwargs in MakeReference.args
    ]
    rm0 = reference_makers[0]

    def requires(self):
        inherited = super().requires()
        return inherited + tuple(self.reference_makers)

    @property
    def extra_signals(self):
        # `post_plot` draws ripple segments on extra axis 0 and SW segments on
        # extra axis 1, presumably matching the order of these signals.
        rm = self.rm0
        envelopes = (
            rm.ripple_envelope,
            rm.SW_envelope,
            rm.toppyr_envelope,
            rm.sr_envelope,
        )
        return [TrainTestSplit(envelope).signal_test for envelope in envelopes]

    def post_plot(self, time_range, input_ax, extra_axes):
        self.add_SWR_segs(input_ax)
        self.add_SW_segs(extra_axes[1])
        self.add_ripple_segs(extra_axes[0])
        # NOTE: drawing the ripple detection thresholds as horizontal lines
        # on `extra_axes[0]` is currently disabled.

    def add_SWR_segs(self, ax):
        """Overlay the SWR segments of every reference maker on `ax`."""
        for ref_maker in self.reference_makers:
            segs = ref_maker.output().read()
            self.add_segs(ax, segs)

    def add_SW_segs(self, ax):
        """Overlay SW segments for every SW multiplier (ripple mult fixed)."""
        fixed_ripple = config.mult_detect_ripple[-1]
        for mult_SW in config.mult_detect_SW:
            ref_maker = MakeReference(
                mult_detect_SW=mult_SW, mult_detect_ripple=fixed_ripple
            )
            self.add_segs(ax, ref_maker.calc_SW_segments())

    def add_ripple_segs(self, ax):
        """Overlay ripple segments for every ripple multiplier (SW mult fixed)."""
        fixed_SW = config.mult_detect_SW[-1]
        for mult_ripple in config.mult_detect_ripple:
            ref_maker = MakeReference(
                mult_detect_SW=fixed_SW, mult_detect_ripple=mult_ripple
            )
            self.add_segs(ax, ref_maker.calc_ripple_segments())

    def add_segs(self, ax, segs):
        """Shade the test-split part of `segs` on `ax`."""
        split = TrainTestSplit(self.rm0.ripple_envelope, segs)
        add_segments(ax, split.segments_test, alpha=0.1)
Exemplo n.º 8
0
 def add_ripple_segs(self, ax):
     """Overlay ripple segments on `ax`, one set per ripple multiplier.

     The SW multiplier is held at the last value of its grid.
     """
     fixed_SW = config.mult_detect_SW[-1]
     for mult_ripple in config.mult_detect_ripple:
         ref_maker = MakeReference(
             mult_detect_SW=fixed_SW, mult_detect_ripple=mult_ripple
         )
         self.add_segs(ax, ref_maker.calc_ripple_segments())
Exemplo n.º 9
0
 def add_SW_segs(self, ax):
     """Overlay sharp-wave segments on `ax`, one set per SW multiplier.

     The ripple multiplier is held at the last value of its grid.
     """
     fixed_ripple = config.mult_detect_ripple[-1]
     for mult_SW in config.mult_detect_SW:
         ref_maker = MakeReference(
             mult_detect_SW=mult_SW, mult_detect_ripple=fixed_ripple
         )
         self.add_segs(ax, ref_maker.calc_SW_segments())
Exemplo n.º 10
0
Arquivo: grid.py Projeto: tfiers/sharp
class PaperGridPlotter(FigureMaker):
    """
    For one envelope maker, plot a grid of a scalar performance value over
    the reference definitions: sharp wave multipliers along the rows, ripple
    multipliers along the columns. A separate colorbar figure is also made.

    Subclasses are expected to set `colorbar_label`, `fstring` and
    `cmap_range`, and to implement `get_data`.
    """

    envelope_maker: EnvelopeMaker = TaskParameter()
    reference_makers = [
        MakeReference(**args) for args in MakeReference.args
    ]
    cmap = get_cmap("viridis")
    # Placeholders (Ellipsis), to be overridden by subclasses.
    colorbar_label: str = ...
    # Format string with an `{x}` field; used for the cell labels and the
    # colorbar tick labels.
    fstring: str = ...

    @property
    def sweepers(self):
        return [
            ThresholdSweeper(
                envelope_maker=self.envelope_maker, reference_maker=rm
            )
            for rm in self.reference_makers
        ]

    def requires(self):
        return self.sweepers

    def output(self):
        return (self.output_grid, self.output_colorbar)

    @property
    def output_grid(self):
        return FigureTarget(output_dir, self.filename)

    @property
    def output_colorbar(self):
        return FigureTarget(output_dir / "cbar", self.filename)

    @property
    def filename(self):
        return f"{self.__class__.__name__} -- {self.envelope_maker.title}"

    def work(self):
        log.info(
            (
                self.filename,
                mean(self.data_matrix),
                percentile(self.data_matrix, [25, 50, 100]),
            )
        )
        self.plot_grid()
        self.plot_colorbar()

    def plot_colorbar(self):
        """Draw a standalone horizontal colorbar and write it to disk."""
        fig, ax = subplots(figsize=paperfig(0.42, 0.16))
        # The colorbar draws itself into `ax`; the returned object is unused.
        ColorbarBase(
            ax=ax,
            orientation="horizontal",
            label=self.colorbar_label,
            norm=self.norm,
            cmap=self.cmap,
            extend="both",
            format=StrMethodFormatter(self.fstring),
        )
        fig.tight_layout()
        self.output_colorbar.write(fig)

    def plot_grid(self):
        """Draw `data_matrix` as an annotated heatmap and write it to disk."""
        fig, ax = subplots(figsize=paperfig(0.55, 0.55))
        data = self.data_matrix
        ax.imshow(
            data,
            cmap=self.cmap,
            origin="lower",
            aspect="auto",
            norm=self.norm,
        )
        ax.set_xlabel("Ripple definition (μV)")
        ax.set_ylabel("Sharp wave definition (μV)")
        ripple_thresholds = self.ripple_thresholds
        SW_thresholds = self.SW_thresholds
        # Label only every other tick, to avoid crowding.
        ax.set_xticks(arange(len(ripple_thresholds))[::2])
        ax.set_yticks(arange(len(SW_thresholds))[::2])
        ax.set_xticklabels([f"{x:.0f}" for x in ripple_thresholds][::2])
        ax.set_yticklabels([f"{x:.0f}" for x in SW_thresholds][::2])
        # Write each cell's value in its center. With `origin="lower"`, row
        # index SW_ix is at y = SW_ix and column ripple_ix is at x = ripple_ix.
        for SW_ix, row in enumerate(data):
            for ripple_ix, value in enumerate(row):
                ax.text(
                    ripple_ix,
                    SW_ix,
                    self.fstring.format(x=value),
                    ha="center",
                    va="center",
                    color=self.text_color(value),
                    size="smaller",
                )
        ax.grid(False)
        fig.tight_layout()
        self.output_grid.write(fig)

    @property
    @cached
    def SW_thresholds(self):
        """SW detection threshold per SW multiplier (ripple mult fixed)."""
        return [
            MakeReference(
                mult_detect_SW=mult_SW,
                mult_detect_ripple=config.mult_detect_ripple[0],
            ).threshold_detect_SW
            for mult_SW in config.mult_detect_SW
        ]

    @property
    @cached
    def ripple_thresholds(self):
        """Ripple detection threshold per ripple multiplier (SW mult fixed)."""
        return [
            MakeReference(
                mult_detect_SW=config.mult_detect_SW[0],
                mult_detect_ripple=mult_ripple,
            ).threshold_detect_ripple
            for mult_ripple in config.mult_detect_ripple
        ]

    @property
    @cached
    def data_matrix(self):
        """`get_data` of each sweep; rows = SW mults, columns = ripple mults."""
        return [
            [
                self.get_data(
                    ThresholdSweeper(
                        envelope_maker=self.envelope_maker,
                        reference_maker=MakeReference(
                            mult_detect_ripple=ripple, mult_detect_SW=SW
                        ),
                    )
                    .output()
                    .read()
                )
                for ripple in config.mult_detect_ripple
            ]
            for SW in config.mult_detect_SW
        ]

    def get_data(self, sweep: ThresholdSweep) -> float:
        """Reduce one threshold sweep to the value shown in its grid cell.

        Placeholder; to be implemented by subclasses.
        """
        ...

    @property
    @cached
    def norm(self):
        # `cmap_range` is not defined here; presumably set by subclasses.
        return Normalize(*self.cmap_range)

    def text_color(self, data_value):
        """Return a grey level ("0.2" or "0.8") that contrasts with the cell."""
        # Map data_value in [vmin, vmax] to [0, 1], then to an RGBA color.
        facecolor = self.cmap(self.norm(data_value))
        _, lightness, _ = rgb_to_hls(*facecolor[:3])
        # Dark text on light cells, light text on dark cells.
        return "0.2" if lightness > 0.4 else "0.8"