示例#1
0
def get_cmt_options(
        context: strax.Context) -> ty.Dict[str, ty.Dict[str, ty.Any]]:
    """
    Collect every CMT-backed option used by the plugins of a context.

    Loops over the ``takes_config`` of all registered plugins. For each
    option whose value (taken from ``context.config`` when it is set there
    and is a CMT option, otherwise the option default) is a CMT option, the
    CMT correction name and the strax option value are recorded.

    :param context: Context with registered plugins.
    :returns: dict mapping the option name to a dict with the keys
        ``'correction'`` (name of the CMT correction) and ``'strax_option'``
        (the option value: either a CMT tuple or a ``cmt://`` URL string).
    :raises RuntimeError: if resolving the correction name of a URL option
        fails with an AttributeError, i.e. it appears to depend on the run id.
    """

    cmt_options = {}
    runid_test_str = '0000'

    for data_type, plugin in context._plugin_class_registry.items():
        for option_key, option in plugin.takes_config.items():
            if option_key in cmt_options:
                # let's not do work twice if needed by > 1 plugin
                continue

            if (option_key in context.config
                    and is_cmt_option(context.config[option_key])):
                opt = context.config[option_key]
            elif is_cmt_option(option.default):
                opt = option.default
            else:
                continue

            # check if it's a URLConfig
            if isinstance(opt, str) and 'cmt://' in opt:
                # only the part after the cmt:// protocol is evaluated below
                _, _, after_cmt = opt.partition('cmt://')
                p = context._get_plugins((data_type, ),
                                         runid_test_str)[data_type]
                context._set_plugin_config(p, runid_test_str, tolerant=False)
                # drop the dummy run id; presumably so that a run-id
                # dependent evaluation raises the AttributeError handled below
                del p.run_id

                p.config[option_key] = after_cmt
                try:
                    correction_name = getattr(p, option_key)
                except AttributeError as err:
                    # make sure the correction name does not depend on runid
                    raise RuntimeError(
                        "Correction names should not depend on runids! "
                        f"Please check your option for {option_key}") from err

                # if there is no other protocol being called before cmt,
                # we will get a string back including the query part
                if option.QUERY_SEP in correction_name:
                    correction_name, _ = option.split_url_kwargs(
                        correction_name)
                cmt_options[option_key] = {
                    'correction': correction_name,
                    'strax_option': opt,
                }

            else:
                # tuple config: the first element is the correction name
                cmt_options[option_key] = {
                    'correction': opt[0],
                    'strax_option': opt,
                }
    return cmt_options
示例#2
0
def plot_single_event(context: strax.Context,
                      run_id,
                      events,
                      event_number=None,
                      **kwargs):
    """
    Wrapper for event_display

    :param context: strax.context
    :param run_id: run id
    :param events: dataframe / numpy array of events. Should either be
        length 1 or the event_number argument should be provided
    :param event_number: (optional) int, if provided, only show this
        event number
    :param kwargs: kwargs for events_display
    :return: see events_display
    :raises ValueError: if, after the optional event_number selection,
        not exactly one event is left
    """
    if event_number is not None:
        events = events[events['event_number'] == event_number]
    if len(events) != 1:
        raise ValueError(f'Make sure to provide an event number or a single '
                         f'event. Got {len(events)} events')

    # On a DataFrame, events[0] looks up a *column* named 0, so take the
    # first row positionally there (DataFrames expose .iloc, numpy arrays
    # do not).
    event = events.iloc[0] if hasattr(events, 'iloc') else events[0]
    return context.event_display(run_id,
                                 time_range=(event['time'],
                                             event['endtime']),
                                 **kwargs)
示例#3
0
 def try_load(self, st: strax.Context, target: str):
     """Load ``target`` for ``self.run_id`` from ``st`` as an array.

     When the data is unavailable, a new ``strax.DataNotAvailable`` is
     raised whose message names the data key and lists every storage
     frontend of ``st``; the original error is chained as its cause.
     """
     try:
         return st.get_array(self.run_id, target)
     except strax.DataNotAvailable as data_error:
         key = st.key_for(self.run_id, target)
         frontends = ''.join(f'\t{frontend}\n' for frontend in st.storage)
         raise strax.DataNotAvailable(
             f'Could not find {key} with the following frontends\n'
             + frontends) from data_error
示例#4
0
        def wrapped_f(context: strax.Context, run_id: str, **kwargs):
            """Validate arguments, load the required data and call ``f``.

            Keyword arguments that are neither known to this wrapper nor in
            ``parameters`` raise a ``TypeError``. Time-selection kwargs are
            converted into an absolute ``time_range``, every data kind in
            ``requires`` is loaded (or, when passed in directly, cut with the
            same selection), and ``f`` is called with either all kwargs (if
            ``'kwargs'`` is in ``parameters``) or only those in ``parameters``.
            """
            # Validate arguments
            known_kwargs = (
                'time_range seconds_range time_within time_selection '
                'ignore_time_warning '
                'selection_str t_reference to_pe config').split()
            for k in kwargs:
                if k not in known_kwargs and k not in parameters:
                    # Python itself also raises TypeError for invalid kwargs
                    raise TypeError(f"Unknown argument {k} for {f.__name__}")

            if 'config' in kwargs:
                context = context.new_context(config=kwargs['config'])
            if 'config' in parameters:
                kwargs['config'] = context.config

            # Say magic words to enable holoviews
            if hv_bokeh:
                global _hv_bokeh_initialized
                if not _hv_bokeh_initialized:
                    import holoviews
                    holoviews.extension('bokeh')
                    _hv_bokeh_initialized = True

            # TODO: This is a placeholder until the corrections system
            # is more fully developed
            if 'to_pe' in parameters and 'to_pe' not in kwargs:
                kwargs['to_pe'] = straxen.get_to_pe(
                    run_id, context.config['gain_model'],
                    context.config['n_tpc_pmts'])

            # Prepare selection arguments. Note that this always sets
            # kwargs['time_range'], possibly to None (no time selection).
            kwargs['time_range'] = context.to_absolute_time_range(
                run_id,
                targets=requires,
                **{
                    k: kwargs.get(k)
                    for k in ('time_range seconds_range time_within'.split())
                })
            kwargs.setdefault('time_selection', default_time_selection)
            kwargs.setdefault('selection_str', None)

            kwargs['t_reference'] = context.estimate_run_start(
                run_id, requires)

            if warn_beyond_sec is not None and not kwargs.get(
                    'ignore_time_warning'):
                tr = kwargs['time_range']
                if tr is None:
                    sec_requested = float('inf')
                else:
                    sec_requested = (tr[1] - tr[0]) / int(1e9)
                if sec_requested > warn_beyond_sec:
                    tr_str = "the entire run" if tr is None else f"{sec_requested} seconds"
                    raise ValueError(
                        f"The author of this mini analysis recommends "
                        f"not requesting more than {warn_beyond_sec} seconds. "
                        f"You are requesting {tr_str}. If you wish to proceed, "
                        "pass ignore_time_warning = True.")

            # Load required data, if any
            if len(requires):
                deps_by_kind = strax.group_by_kind(requires, context=context)
                for dkind, dtypes in deps_by_kind.items():
                    if dkind in kwargs:
                        # Already have data, just apply cuts
                        kwargs[dkind] = context.apply_selection(
                            kwargs[dkind],
                            selection_str=kwargs['selection_str'],
                            time_range=kwargs['time_range'],
                            time_selection=kwargs['time_selection'])
                    else:
                        kwargs[dkind] = context.get_array(
                            run_id,
                            dtypes,
                            selection_str=kwargs['selection_str'],
                            time_range=kwargs['time_range'],
                            time_selection=kwargs['time_selection'],
                            # Arguments for new context, if needed
                            config=kwargs.get('config'),
                            register=kwargs.get('register'),
                            storage=kwargs.get('storage', tuple()))

                # If user did not give time kwargs, but the function expects
                # a time_range, try to add one based on the time range of the
                # data. The key 'time_range' always exists at this point (set
                # above, possibly to None), so it must be assigned explicitly;
                # dict.setdefault would silently do nothing.
                base_dkind = list(deps_by_kind.keys())[0]
                x = kwargs[base_dkind]
                if len(x) and kwargs.get('time_range') is None:
                    x0 = x.iloc[0] if isinstance(x, pd.DataFrame) else x[0]
                    try:
                        kwargs['time_range'] = (x0['time'],
                                                strax.endtime(x).max())

                    except AttributeError:
                        # If x is a holoviews dataset, this will fail.
                        pass

            if 'seconds_range' in parameters:
                if kwargs.get('time_range') is None:
                    scr = None
                else:
                    scr = tuple([(t - kwargs['t_reference']) / int(1e9)
                                 for t in kwargs['time_range']])
                kwargs.setdefault('seconds_range', scr)

            kwargs.setdefault('run_id', run_id)
            kwargs.setdefault('context', context)

            if 'kwargs' in parameters:
                # Likely this will be passed to another mini-analysis
                to_pass = kwargs
                # Do not pass time_range and seconds_range both (unless explicitly requested)
                # strax does not like that
                if 'seconds_range' in to_pass and 'seconds_range' not in parameters:
                    del to_pass['seconds_range']
                if 'time_within' in to_pass and 'time_within' not in parameters:
                    del to_pass['time_within']
            else:
                # Pass only arguments the function wants
                to_pass = {k: v for k, v in kwargs.items() if k in parameters}
            return f(**to_pass)
示例#5
0
def apply_cmt_version(context: strax.Context, cmt_global_version: str) -> None:
    """Point every CMT-backed option of ``context`` at one global CMT version.

    :param context: context whose CMT options are updated via
        ``context.set_config``
    :param cmt_global_version: A specific CMT global version, or 'latest' to
        get the newest one
    :raises CMTVersionError: if the global version lacks corrections used by
        this context (for ``"global_ONLINE"`` only a ``UserWarning`` is issued)
    :returns None
    """
    local_versions = get_cmt_local_versions(cmt_global_version)

    # Position-reconstruction algorithm: the context config wins, otherwise
    # use the default declared by the peak_positions plugin.
    posrec_option = 'default_reconstruction_algorithm'
    posrec_algo = (
        context.config[posrec_option]
        if posrec_option in context.config
        else context._plugin_class_registry['peak_positions'].takes_config[posrec_option].default
    )

    cmt_options = straxen.get_corrections.get_cmt_options(context)

    # Corrections that the requested global version does not provide are
    # collected so that incompatible fixed global versions fail loudly.
    # The dict keys serve as an insertion-ordered set.
    missing = {}
    new_config = {}

    for option, info in cmt_options.items():
        # CMT correction name; not always equal to the strax option name
        name = info['correction']
        # current option value: a CMT tuple or a URLConfig string
        current = info['strax_option']

        # Tuple configs keep the posrec algorithm separate, so the suffix is
        # appended here; URL configs should already include it.
        if name in posrec_corrections_basenames:
            name += f"_{posrec_algo}"

        if name not in local_versions:
            missing[name] = None
            continue

        version = local_versions[name]
        if isinstance(current, str) and 'cmt://' in current:
            new_config[option] = replace_url_version(current, version)
        else:
            # tuple config: swap in the new version, keep the other fields
            new_config[option] = (current[0], version, current[2])

    if missing:
        failed_keys = ', '.join(missing)
        msg = (f"CMT version {cmt_global_version} is not compatible with this straxen version! "
               f"CMT {cmt_global_version} is missing these corrections: {failed_keys}")

        # only raise a warning if we are working with the online context
        if cmt_global_version == "global_ONLINE":
            warnings.warn(msg, UserWarning)
        else:
            raise CMTVersionError(msg)

    context.set_config(new_config)
示例#6
0
def compare_outcomes(
    st: strax.Context,
    data: np.ndarray,
    st_alt: ty.Optional[strax.Context] = None,
    data_alt: ty.Optional[np.ndarray] = None,
    match_fuzz: int = 500,
    plot_fuzz: int = 500,
    max_peaks: int = 10,
    default_label: str = 'default',
    custom_label: str = 'custom',
    fig_dir: ty.Union[None, str] = None,
    show: bool = True,
    randomize: bool = True,
    different_by: ty.Optional[ty.Union[bool, str]] = 'acceptance_fraction',
    run_id: ty.Union[None, str] = None,
    raw: bool = False,
    pulse: bool = True,
) -> None:
    """
    Compare the outcomes of two contexts with one another. To allow for
    selections, the data of each context is passed explicitly (``data``
    for ``st`` and ``data_alt`` for ``st_alt``).

    :param st: the context of the current master, to compare with st_alt
    :param data: the data consistent with ``st``, can be cut to select
        certain data
    :param st_alt: context to compare ``st`` with. Must be given together
        with ``data_alt`` (or both omitted to plot ``st`` only)
    :param data_alt: the data from ``st_alt``, should be the same length
        as ``data``
    :param match_fuzz: Extend loading peaks this many ns to allow for
        small shifts in reconstruction. Will extend the time range left
        and right
    :param plot_fuzz: Make the plot slightly larger with this many ns
        for readability
    :param max_peaks: max number of peaks to be shown. Set to 1 for
        plotting a single peak.
    :param default_label: How to label the default reconstruction
    :param custom_label: How to label the custom reconstruction
    :param fig_dir: Where to save figures (if None, don't save)
    :param show: show the figures or not.
    :param randomize: randomly order peaks to get a random sample of
        <max_peaks> every time
    :param different_by: Field to filter waveforms by. Only show
        waveforms where this field is different in data. If False, plot
        any waveforms from the two data sets.
    :param run_id: Optional argument in case run_id is not a field in
        the data.
    :param raw: include raw-records-trace
    :param pulse: plot raw-record traces.
    :raises RuntimeError: if only one of st_alt / data_alt is given
    :return: None
    """
    plot_difference = st_alt is not None
    if plot_difference != (data_alt is not None):
        raise RuntimeError(
            'Both st_alt and data_alt should be specified simultaneously')

    # data_alt is None exactly when there is no alternative context
    _check_args(data, data_alt, run_id)
    if plot_difference:
        peaks_idx = _get_peak_idxs_from_args(data, randomize, data_alt,
                                             different_by)
    else:
        peaks_idx = _get_peak_idxs_from_args(data, randomize)

    for peak_i in tqdm(peaks_idx[:max_peaks]):
        try:
            if 'run_id' in data.dtype.names:
                run_id = data[peak_i]['run_id']
                run_mask = data['run_id'] == run_id
            else:
                run_mask = np.ones(len(data), dtype=np.bool_)
            t_range, start_end, xlim = _get_time_ranges(
                data, peak_i, match_fuzz, plot_fuzz)

            # panels: truth + default peak, optional alt peak, raw, pulse
            n_panels = 2 + int(plot_difference) + int(raw) + int(pulse)
            axes = iter(_get_axes_for_compare_plot(n_panels))

            plt.sca(next(axes))
            _plot_truth(data[run_mask], start_end, t_range)

            if raw:
                plt.sca(next(axes))
                st.plot_records_matrix(run_id,
                                       raw=True,
                                       single_figure=False,
                                       time_range=t_range,
                                       time_selection='touching')
                for edge in t_range:
                    axvline(edge / 1e9)

            if pulse:
                plt.sca(next(axes))
                rr_simple_plot(st, run_id, t_range)

            # (context, data, label, label the x-axis?) per peak panel;
            # only the bottom-most panel gets the x-axis label
            peak_panels = [(st, data, default_label, not plot_difference)]
            if plot_difference:
                peak_panels.append((st_alt, data_alt, custom_label, True))
            for panel_st, panel_data, label, label_x in peak_panels:
                plt.sca(next(axes))
                _plot_peak(panel_st,
                           panel_data,
                           label,
                           peak_i,
                           t_range,
                           xlim,
                           run_id,
                           label_x_axis=label_x)

            _save_and_show('example_wf_diff', fig_dir, show, peak_i)
        except (ValueError, RuntimeError) as err:
            print(f'Error making {peak_i}: {type(err)}, {err}')
            plt.show()
示例#7
0
def plot_wf(st: strax.Context,
            containers,
            run_id, plot_log=True, plot_extension=0, hit_pattern=True,
            timestamp=True, time_fmt="%d-%b-%Y (%H:%M:%S)",
            **kwargs):
    """
    Combined waveform plot.

    :param st: strax.Context
    :param containers: peaks/records/events from which we want to plot
        all the peaks that are within their time range +- the
        plot_extension. For example, you can provide three adjacent
        peaks and plot them in a single figure.
    :param run_id: run_id of the containers (a single string)
    :param plot_log: Plot the y-scale of the wf in log-space
    :param plot_extension: include this many nanoseconds around the
        containers (can be scalar or list of (-left_extension,
        right_extension)). Non-integer values are accepted.
    :param hit_pattern: include the hit-pattern in the wf
    :param timestamp: print the timestamp to the plot
    :param time_fmt: format of the timestamp (datetime.strftime format)
    :param kwargs: kwargs for plot_peaks
    :raises ValueError: if run_id is not a string, or plot_extension is
        neither a scalar nor of length 2
    """

    if not isinstance(run_id, str):
        raise ValueError(f'Insert single run_id, not {run_id}')

    p = containers  # usually peaks
    run_start, _ = st.estimate_run_start_and_end(run_id)
    # Subtract the run start while still in integer nanoseconds (keeps full
    # precision for large absolute timestamps), then switch to float so
    # that adding a non-integer extension does not fail the in-place add
    # on an int64 array.
    t_range = np.array([p['time'].min(), strax.endtime(p).max()]) - run_start
    t_range = t_range.astype(np.float64)

    # Extend the time range if needed.
    if not np.iterable(plot_extension):
        t_range += np.array([-plot_extension, plot_extension])
    elif len(plot_extension) == 2:
        # a left extension of 0 is not "later than the start", so only
        # warn for strictly positive values
        if plot_extension[0] > 0:
            warnings.warn('Left extension is positive (i.e. later than start '
                          'of container).')
        t_range += np.asarray(plot_extension, dtype=np.float64)
    else:
        raise ValueError('Wrong dimensions for plot_extension. Use scalar or '
                         'object of len( ) == 2')
    # ns -> s relative to the run start; times before the run start are
    # clipped to 0
    t_range = t_range / 10 ** 9
    t_range = np.clip(t_range, 0, np.inf)

    if hit_pattern:
        plt.figure(figsize=(14, 11))
        plt.subplot(212)
    else:
        plt.figure(figsize=(14, 5))
    # Plot the wf
    plot_peaks(st, run_id, seconds_range=t_range, single_figure=False, **kwargs)

    if timestamp:
        _ax = plt.gca()
        # NOTE(review): fromtimestamp without tz renders the local time of
        # the machine running this; confirm that is intended (vs UTC).
        t_stamp = datetime.datetime.fromtimestamp(
            containers['time'].min() / 10 ** 9).strftime(time_fmt)
        _ax.text(0.975, 0.925, t_stamp,
                 horizontalalignment='right',
                 verticalalignment='top',
                 transform=_ax.transAxes)
    # Select the additional two panels to show the top and bottom arrays
    if hit_pattern:
        axes = plt.subplot(221), plt.subplot(222)
        plot_hit_pattern(st, run_id,
                         seconds_range=t_range,
                         axes=axes,
                         vmin=1 if plot_log else None,
                         log_scale=plot_log,
                         label='Area per channel [PE]')
示例#8
0
        def wrapped_f(context: strax.Context, run_id: str, **kwargs):
            """Validate arguments, load the required data and call ``f``.

            Keyword arguments that are neither known to this wrapper nor in
            ``parameters`` raise a ``TypeError``. Time-selection kwargs are
            converted into an absolute ``time_range``, every data kind in
            ``requires`` is loaded (or, when passed in directly, cut with the
            same selection), and ``f`` receives only the names listed in
            ``parameters`` (``run_id`` and ``context`` are always filled in).
            """
            # Validate arguments
            known_kwargs = (
                'time_range seconds_range time_within time_selection '
                'selection_str t_reference to_pe config').split()
            for k in kwargs:
                if k not in known_kwargs and k not in parameters:
                    # Python itself also raises TypeError for invalid kwargs
                    raise TypeError(f"Unknown argument {k} for {f.__name__}")

            if 'config' in kwargs:
                context = context.new_context(config=kwargs['config'])

            # Say magic words to enable holoviews
            if hv_bokeh:
                global _hv_bokeh_initialized
                if not _hv_bokeh_initialized:
                    hv.extension('bokeh')
                    _hv_bokeh_initialized = True

            # TODO: This is a placeholder until the corrections system
            # is more fully developed
            if 'to_pe' in parameters and 'to_pe' not in kwargs:
                kwargs['to_pe'] = nEXO_strax.get_to_pe(
                    run_id,
                    'https://raw.githubusercontent.com/XENONnT/'
                    'strax_auxiliary_files/master/to_pe.npy')

            # Prepare selection arguments
            kwargs['time_range'] = context.to_absolute_time_range(
                run_id,
                targets=requires,
                **{k: kwargs.get(k)
                   for k in ('time_range seconds_range time_within'.split())})
            kwargs.setdefault('time_selection', 'fully_contained')
            kwargs.setdefault('selection_str', None)

            if ('t_reference' in parameters
                    and kwargs.get('t_reference') is None):
                kwargs['t_reference'] = context.estimate_run_start(
                    run_id, requires)

            # Load required data
            deps_by_kind = strax.group_by_kind(
                requires, context=context, require_time=False)
            for dkind, dtypes in deps_by_kind.items():
                if dkind in kwargs:
                    # Already have data, just apply cuts
                    kwargs[dkind] = context.apply_selection(
                        kwargs[dkind],
                        selection_str=kwargs['selection_str'],
                        time_range=kwargs['time_range'],
                        time_selection=kwargs['time_selection'])
                else:
                    kwargs[dkind] = context.get_array(
                        run_id,
                        dtypes,
                        selection_str=kwargs['selection_str'],
                        time_range=kwargs['time_range'],
                        time_selection=kwargs['time_selection'])

            # If user did not give time kwargs, but the function expects
            # a time_range, add one based on the time range of the data,
            # or (NaN, NaN) if there is no data. Skipped when nothing is
            # required, since there is then no data kind to derive it from.
            if kwargs.get('time_range') is None and deps_by_kind:
                base_dkind = list(deps_by_kind.keys())[0]
                x = kwargs[base_dkind]
                if len(x):
                    # only index the first row when one exists; indexing an
                    # empty container would raise before the NaN fallback
                    x0 = x.iloc[0] if isinstance(x, pd.DataFrame) else x[0]
                    kwargs['time_range'] = (x0['time'],
                                            strax.endtime(x).max())
                else:
                    kwargs['time_range'] = (float('nan'), float('nan'))

            # Pass only the arguments the function wants
            to_pass = dict()
            for k in parameters:
                if k == 'run_id':
                    to_pass['run_id'] = run_id
                elif k == 'context':
                    to_pass['context'] = context
                elif k in kwargs:
                    to_pass[k] = kwargs[k]
                # If we get here, let's hope the function defines a default...

            return f(**to_pass)