class InputDataMixin:
    """
    Mixin giving access to the downsampled recordings and the reference
    segments, both in full and split into a train and a test part.
    """

    downsampler = DownsampleAllRecordings()
    reference_maker = MakeReference()

    # Should be included in the return values of a Task's `requires()`.
    input_data_makers = (downsampler, reference_maker)

    # -------------
    # All channels:

    @property
    @cached
    def multichannel_full(self):
        """Multichannel signal, as provided by the downsampler."""
        return self.downsampler.get_multichannel()

    @property
    def multichannel_train(self):
        """Train part of the multichannel signal."""
        split = TrainTestSplit(self.multichannel_full)
        return split.signal_train

    @property
    def multichannel_test(self):
        """Test part of the multichannel signal."""
        split = TrainTestSplit(self.multichannel_full)
        return split.signal_test

    # -----------------
    # Selected channel:

    @property
    @cached
    def reference_channel_full(self):
        """Reference channel signal, as provided by the downsampler."""
        return self.downsampler.get_reference_channel()

    @property
    def reference_channel_train(self):
        """Train part of the reference channel signal."""
        split = TrainTestSplit(self.reference_channel_full)
        return split.signal_train

    @property
    def reference_channel_test(self):
        """Test part of the reference channel signal."""
        split = TrainTestSplit(self.reference_channel_full)
        return split.signal_test

    # -------------------
    # Reference segments:

    @property
    @cached
    def reference_segs_all(self):
        """All reference segments, read from the reference maker's output."""
        target = self.reference_maker.output()
        return target.read()

    @property
    def reference_segs_train(self):
        """Reference segments of the train part."""
        split = self._split_refsegs
        return split.segments_train

    @property
    def reference_segs_test(self):
        """Reference segments of the test part."""
        split = self._split_refsegs
        return split.segments_test

    @property
    def _split_refsegs(self):
        """Train/test split of the reference channel and all its segments."""
        signal = self.reference_channel_full
        segments = self.reference_segs_all
        return TrainTestSplit(signal, segments)
def ripple_thresholds(self):
    """
    List the ripple detection threshold of a `MakeReference` task, for
    each ripple multiplier in `config.mult_detect_ripple`. The sharp wave
    multiplier is fixed to the first entry of `config.mult_detect_SW`.
    """
    thresholds = []
    for ripple_multiplier in config.mult_detect_ripple:
        reference_maker = MakeReference(
            mult_detect_SW=config.mult_detect_SW[0],
            mult_detect_ripple=ripple_multiplier,
        )
        thresholds.append(reference_maker.threshold_detect_ripple)
    return thresholds
def SW_thresholds(self):
    """
    List the sharp wave detection threshold of a `MakeReference` task, for
    each sharp wave multiplier in `config.mult_detect_SW`. The ripple
    multiplier is fixed to the first entry of `config.mult_detect_ripple`.
    """
    thresholds = []
    for SW_multiplier in config.mult_detect_SW:
        reference_maker = MakeReference(
            mult_detect_SW=SW_multiplier,
            mult_detect_ripple=config.mult_detect_ripple[0],
        )
        thresholds.append(reference_maker.threshold_detect_SW)
    return thresholds
class ThresholdSweeper(SharpTask):
    """
    Evaluates the output envelope of a detector at a series of detection
    thresholds, and writes the collected evaluations as one threshold
    sweep.
    """

    output_root = intermediate_output_dir / "threshold-sweeps"

    envelope_maker: EnvelopeMaker = TaskParameter()
    reference_maker: MakeReference = TaskParameter(default=MakeReference())

    def requires(self):
        """Both the envelope maker and the reference maker are needed."""
        return self.envelope_maker, self.reference_maker

    def output(self) -> ThresholdSweepFile:
        """
        Sweep file located at
        `<output_root>/<reference filename>/<envelope subdir>/<envelope filename>`.
        """
        sweep_dir = (
            self.output_root
            / self.reference_maker.output_filename
            / self.envelope_maker.output_subdir
        )
        return ThresholdSweepFile(
            directory=sweep_dir,
            filename=self.envelope_maker.output_filename,
        )

    @property
    def lockout_time(self) -> float:
        """
        Lockout time passed to each threshold evaluation. Taken directly
        from the config; a variant based on a percentile of the reference
        segment durations is currently disabled.
        """
        return config.lockout_time

    def work(self):
        """
        Request thresholds from the sweep one by one and evaluate each on
        the test envelope, until `config.num_thresholds` thresholds are
        present. Then write the sweep to the output file.
        """
        sweep = ThresholdSweep()
        threshold_range = self.envelope_maker.envelope_test.range
        log.info(
            f"Evaluating {config.num_thresholds} thresholds, "
            f"with a lockout time of {1000 * self.lockout_time:.3g} ms."
        )
        # Only the reference segments of the test part are evaluated against.
        all_reference_segs = self.reference_maker.output().read()
        split = TrainTestSplit(self.envelope_maker.envelope, all_reference_segs)
        test_reference_segs = split.segments_test
        while len(sweep.thresholds) < config.num_thresholds:
            next_threshold = sweep.get_next_threshold(threshold_range)
            evaluation = evaluate_threshold(
                self.envelope_maker.envelope_test,
                next_threshold,
                self.lockout_time,
                test_reference_segs,
            )
            sweep.add_threshold_evaluation(evaluation)
        self.output().write(sweep)
def data_matrix(self):
    """
    Build a nested list with one row per sharp wave multiplier and one
    column per ripple multiplier. Each entry is `get_data` applied to the
    sweep read from the corresponding `ThresholdSweeper` output.
    """
    matrix = []
    for SW_multiplier in config.mult_detect_SW:
        row = []
        for ripple_multiplier in config.mult_detect_ripple:
            reference_maker = MakeReference(
                mult_detect_ripple=ripple_multiplier,
                mult_detect_SW=SW_multiplier,
            )
            sweeper = ThresholdSweeper(
                envelope_maker=self.envelope_maker,
                reference_maker=reference_maker,
            )
            sweep = sweeper.output().read()
            row.append(self.get_data(sweep))
        matrix.append(row)
    return matrix
from sharp.config.load import config
from sharp.data.hardcoded.style import blue, pink
from sharp.data.types.evaluation.sweep import ThresholdSweep
from sharp.data.types.evaluation.threshold import ThresholdEvaluation
from sharp.tasks.evaluate.sweep import ThresholdSweeper
from sharp.tasks.neuralnet.apply import ApplyRNN
from sharp.tasks.plot.base import FigureMaker
from sharp.tasks.signal.online_bpf import ApplyOnlineBPF
from sharp.tasks.signal.reference import MakeReference

# Subdirectory "minipaper" of the general FigureMaker output directory.
output_dir = FigureMaker.output_dir / "minipaper"

# Single reference maker shared by both sweepers below. It uses the entries
# at index 3 of both multiplier lists from the config.
rm = MakeReference(
    mult_detect_SW=config.mult_detect_SW[3],
    mult_detect_ripple=config.mult_detect_ripple[3],
    # mult_detect_ripple=config.mult_detect_ripple[4],
)

# One threshold sweeper per detector: recurrent neural net and online
# band-pass filter.
sweeper_rnn = ThresholdSweeper(reference_maker=rm, envelope_maker=ApplyRNN())
sweeper_bpf = ThresholdSweeper(reference_maker=rm, envelope_maker=ApplyOnlineBPF())
color_rnn = pink
color_bpf = blue

# Parallel tuples, each ordered (band-pass filter, recurrent neural net).
sweepers = (sweeper_bpf, sweeper_rnn)
colors = (color_bpf, color_rnn)
labels = ("Band-pass filter", "Recurrent neural net")


# NOTE(review): `Tuple` is used in the signature below, but no `typing`
# import is visible in this excerpt -- confirm it is imported elsewhere.
def get_sweeps() -> Tuple[ThresholdSweep, ...]:
class PlotReferenceMaker(TimeRangesPlotter):
    """
    Plots the input signal together with the envelopes computed by the
    reference maker, and overlays the reference segments. Only the test
    part of the data is shown.
    """

    selected_time_ranges_only = False
    reference_channel_only = False
    full_range_scalebars = True
    output_dir = TimeRangesPlotter.output_dir / "reference"
    window_size = 2
    figwidth = 1

    # One reference maker per argument set listed in `MakeReference.args`.
    reference_makers = [MakeReference(**args) for args in MakeReference.args]
    # The first one supplies the envelopes that are plotted.
    rm0 = reference_makers[0]

    def requires(self):
        """Parent requirements, extended with all reference makers."""
        own_requirements = tuple(self.reference_makers)
        return super().requires() + own_requirements

    @property
    def extra_signals(self):
        """
        Test parts of the ripple, SW, toppyr and sr envelopes of the first
        reference maker, in that order.
        """
        rm = self.rm0
        envelopes = (
            rm.ripple_envelope,
            rm.SW_envelope,
            rm.toppyr_envelope,
            rm.sr_envelope,
        )
        test_signals = []
        for envelope in envelopes:
            test_signals.append(TrainTestSplit(envelope).signal_test)
        return test_signals

    def post_plot(self, time_range, input_ax, extra_axes):
        """
        Overlay segments on the finished axes: SWR segments on the input
        axes, SW segments on the second extra axes, and ripple segments on
        the first extra axes. No threshold lines are drawn.
        """
        ripple_ax, SW_ax = extra_axes[0], extra_axes[1]
        self.add_SWR_segs(input_ax)
        self.add_SW_segs(SW_ax)
        self.add_ripple_segs(ripple_ax)

    def add_SWR_segs(self, ax):
        """Overlay the output segments of every reference maker on `ax`."""
        for reference_maker in self.reference_makers:
            segments = reference_maker.output().read()
            self.add_segs(ax, segments)

    def add_SW_segs(self, ax):
        """
        Overlay the SW segments for every SW multiplier in the config. The
        ripple multiplier is fixed to the last one in the config.
        """
        ripple_multiplier = config.mult_detect_ripple[-1]
        for SW_multiplier in config.mult_detect_SW:
            reference_maker = MakeReference(
                mult_detect_SW=SW_multiplier,
                mult_detect_ripple=ripple_multiplier,
            )
            segments = reference_maker.calc_SW_segments()
            self.add_segs(ax, segments)

    def add_ripple_segs(self, ax):
        """
        Overlay the ripple segments for every ripple multiplier in the
        config. The SW multiplier is fixed to the last one in the config.
        """
        SW_multiplier = config.mult_detect_SW[-1]
        for ripple_multiplier in config.mult_detect_ripple:
            reference_maker = MakeReference(
                mult_detect_SW=SW_multiplier,
                mult_detect_ripple=ripple_multiplier,
            )
            segments = reference_maker.calc_ripple_segments()
            self.add_segs(ax, segments)

    def add_segs(self, ax, segs):
        """
        Draw the test part of `segs` on `ax` with an alpha of 0.1. The
        ripple envelope of the first reference maker is the signal used
        for the train/test split.
        """
        split = TrainTestSplit(self.rm0.ripple_envelope, segs)
        add_segments(ax, split.segments_test, alpha=0.1)
def add_ripple_segs(self, ax):
    """
    Add the ripple segments to `ax`, once for each ripple multiplier in
    `config.mult_detect_ripple`. The sharp wave multiplier is held at the
    last entry of `config.mult_detect_SW`.
    """
    SW_multiplier = config.mult_detect_SW[-1]
    for ripple_multiplier in config.mult_detect_ripple:
        reference_maker = MakeReference(
            mult_detect_SW=SW_multiplier,
            mult_detect_ripple=ripple_multiplier,
        )
        ripple_segments = reference_maker.calc_ripple_segments()
        self.add_segs(ax, ripple_segments)
def add_SW_segs(self, ax):
    """
    Add the sharp wave segments to `ax`, once for each sharp wave
    multiplier in `config.mult_detect_SW`. The ripple multiplier is held
    at the last entry of `config.mult_detect_ripple`.
    """
    ripple_multiplier = config.mult_detect_ripple[-1]
    for SW_multiplier in config.mult_detect_SW:
        reference_maker = MakeReference(
            mult_detect_SW=SW_multiplier,
            mult_detect_ripple=ripple_multiplier,
        )
        SW_segments = reference_maker.calc_SW_segments()
        self.add_segs(ax, SW_segments)
class PaperGridPlotter(FigureMaker):
    """
    Draws one number per (sharp wave multiplier, ripple multiplier) pair
    as a colored, annotated grid, and a matching colorbar as a separate
    figure.

    `colorbar_label`, `fstring` and `get_data` are placeholders (`...`) in
    this class, and `cmap_range` is read by `norm` without being defined
    here -- presumably subclasses supply all of them.
    """

    envelope_maker: EnvelopeMaker = TaskParameter()
    reference_makers = [
        MakeReference(**args) for args in MakeReference.args
    ]
    cmap = get_cmap("viridis")
    colorbar_label: str = ...
    fstring: str = ...

    @property
    def sweepers(self):
        """One threshold sweeper per reference maker, for this envelope maker."""
        sweepers = []
        for reference_maker in self.reference_makers:
            sweeper = ThresholdSweeper(
                envelope_maker=self.envelope_maker,
                reference_maker=reference_maker,
            )
            sweepers.append(sweeper)
        return sweepers

    def requires(self):
        return self.sweepers

    def output(self):
        """The grid figure and the colorbar figure, in that order."""
        return self.output_grid, self.output_colorbar

    @property
    def output_grid(self):
        return FigureTarget(output_dir, self.filename)

    @property
    def output_colorbar(self):
        colorbar_dir = output_dir / "cbar"
        return FigureTarget(colorbar_dir, self.filename)

    @property
    def filename(self):
        """Class name and envelope maker title, joined by " -- "."""
        class_name = self.__class__.__name__
        return f"{class_name} -- {self.envelope_maker.title}"

    def work(self):
        """Log summary statistics of the data, then write both figures."""
        summary = (
            self.filename,
            mean(self.data_matrix),
            percentile(self.data_matrix, [25, 50, 100]),
        )
        log.info(summary)
        self.plot_grid()
        self.plot_colorbar()

    def plot_colorbar(self):
        """Draw a horizontal colorbar in its own figure and write it out."""
        fig, ax = subplots(figsize=paperfig(0.42, 0.16))
        # The colorbar object itself is not used after construction.
        colorbar = ColorbarBase(
            ax=ax,
            orientation="horizontal",
            label=self.colorbar_label,
            norm=self.norm,
            cmap=self.cmap,
            extend="both",
            format=StrMethodFormatter(self.fstring),
        )
        fig.tight_layout()
        self.output_colorbar.write(fig)

    def plot_grid(self):
        """
        Draw the data matrix as an image, with ripple thresholds along the
        x-axis and sharp wave thresholds along the y-axis. Every cell is
        annotated with its formatted value. Then write the figure out.
        """
        fig, ax = subplots(figsize=paperfig(0.55, 0.55))
        ax.imshow(
            self.data_matrix,
            cmap=self.cmap,
            origin="lower",
            aspect="auto",
            norm=self.norm,
        )
        ax.set_xlabel("Ripple definition (μV)")
        ax.set_ylabel("Sharp wave definition (μV)")

        num_ripple = len(self.ripple_thresholds)
        num_SW = len(self.SW_thresholds)
        # Only every other row and column gets a tick and a label.
        every_other = slice(None, None, 2)
        ripple_labels = [f"{x:.0f}" for x in self.ripple_thresholds]
        SW_labels = [f"{x:.0f}" for x in self.SW_thresholds]
        ax.set_xticks(arange(num_ripple)[every_other])
        ax.set_yticks(arange(num_SW)[every_other])
        ax.set_xticklabels(ripple_labels[every_other])
        ax.set_yticklabels(SW_labels[every_other])

        for SW_ix in range(num_SW):
            for ripple_ix in range(num_ripple):
                cell_value = self.data_matrix[SW_ix][ripple_ix]
                ax.text(
                    ripple_ix,
                    SW_ix,
                    self.fstring.format(x=cell_value),
                    ha="center",
                    va="center",
                    color=self.text_color(cell_value),
                    size="smaller",
                )

        ax.grid(False)
        fig.tight_layout()
        self.output_grid.write(fig)

    @property
    @cached
    def SW_thresholds(self):
        """
        Sharp wave detection threshold for each sharp wave multiplier in
        the config, with the ripple multiplier fixed to the first one.
        """
        thresholds = []
        for SW_multiplier in config.mult_detect_SW:
            reference_maker = MakeReference(
                mult_detect_SW=SW_multiplier,
                mult_detect_ripple=config.mult_detect_ripple[0],
            )
            thresholds.append(reference_maker.threshold_detect_SW)
        return thresholds

    @property
    @cached
    def ripple_thresholds(self):
        """
        Ripple detection threshold for each ripple multiplier in the
        config, with the sharp wave multiplier fixed to the first one.
        """
        thresholds = []
        for ripple_multiplier in config.mult_detect_ripple:
            reference_maker = MakeReference(
                mult_detect_SW=config.mult_detect_SW[0],
                mult_detect_ripple=ripple_multiplier,
            )
            thresholds.append(reference_maker.threshold_detect_ripple)
        return thresholds

    @property
    @cached
    def data_matrix(self):
        """
        Nested list with one row per sharp wave multiplier and one column
        per ripple multiplier. Each entry is `get_data` applied to the
        sweep read from the corresponding threshold sweeper's output.
        """
        matrix = []
        for SW_multiplier in config.mult_detect_SW:
            row = []
            for ripple_multiplier in config.mult_detect_ripple:
                reference_maker = MakeReference(
                    mult_detect_ripple=ripple_multiplier,
                    mult_detect_SW=SW_multiplier,
                )
                sweeper = ThresholdSweeper(
                    envelope_maker=self.envelope_maker,
                    reference_maker=reference_maker,
                )
                sweep = sweeper.output().read()
                row.append(self.get_data(sweep))
            matrix.append(row)
        return matrix

    def get_data(self, sweep: ThresholdSweep) -> float:
        ...

    @property
    @cached
    def norm(self):
        """Color normalization built from `cmap_range`."""
        return Normalize(*self.cmap_range)

    def text_color(self, data_value):
        """
        Choose the annotation color for a cell: "0.2" when the lightness
        of the cell's face color exceeds 0.4, and "0.8" otherwise.
        """
        # Normalize the data value, then look up its RGBA color in the
        # colormap. Only the RGB part is converted to HLS.
        facecolor = self.cmap(self.norm(data_value))
        _, lightness, _ = rgb_to_hls(*facecolor[:3])
        return "0.2" if lightness > 0.4 else "0.8"