コード例 #1
0
    def register(self, plugin_class, provides=None):
        """Register plugin_class as provider for one or more data types.

        :param plugin_class: class inheriting from StraxPlugin. A tuple or
            list of such classes may be passed instead (only when provides
            is None); each class is then registered individually.
        :param provides: optional extra data type name, or list/tuple of
            names, for which this plugin is registered as well. The object
            passed in is never modified.

        Plugins always register for the data type specified in the .provides
        class attribute. If that attribute is missing, it is constructed from
        the class name (CamelCase -> snake_case) and stored on the class.

        Returns plugin_class (so this can be used as a decorator). When a
        sequence of classes is passed, None is returned.
        """
        if isinstance(plugin_class, (tuple, list)) and provides is None:
            # Shortcut for multiple registration
            for x in plugin_class:
                self.register(x)
            return

        if not hasattr(plugin_class, 'provides'):
            # No output name specified: construct one from the class name
            plugin_class.provides = strax.camel_to_snake(plugin_class.__name__)

        if provides is None:
            data_types = [plugin_class.provides]
        elif isinstance(provides, str):
            # A single name: do not iterate over its characters
            data_types = [provides, plugin_class.provides]
        else:
            # Build a new list. An in-place `provides += [...]` would extend
            # the caller's list, leaking data types into later calls that
            # reuse the same list object.
            data_types = [*provides, plugin_class.provides]

        for p in data_types:
            self._plugin_class_registry[p] = plugin_class

        return plugin_class
コード例 #2
0
    def register(self, plugin_class):
        """Register plugin_class as provider for each data type it provides.

        :param plugin_class: class inheriting from strax.Plugin, or a tuple
            or list of such classes, in which case each one is registered
            individually.

        The data types are taken from the class's .provides attribute. If a
        plugin class omits it, one is constructed from its class name
        (CamelCase -> snake_case). A .provides given as a single string is
        replaced on the class by a one-element tuple.

        Returns plugin_class (so this can be used as a decorator). When a
        sequence of classes is passed, None is returned.
        """
        if isinstance(plugin_class, (tuple, list)):
            # Shortcut for multiple registration
            for single_class in plugin_class:
                self.register(single_class)
            return None

        if hasattr(plugin_class, 'provides'):
            # Normalize a bare string to a one-element tuple
            if isinstance(plugin_class.provides, str):
                plugin_class.provides = (plugin_class.provides,)
        else:
            # No output name specified: construct one from the class name
            derived_name = strax.camel_to_snake(plugin_class.__name__)
            plugin_class.provides = (derived_name,)

        registry = self._plugin_class_registry
        for data_type in plugin_class.provides:
            registry[data_type] = plugin_class

        return plugin_class
コード例 #3
0
ファイル: test_cut_plugin.py プロジェクト: jmosbacher/strax
def test_cut_plugin(input_peaks, cut_threshold):
    """Check that a minimal strax.CutPlugin yields a correct boolean cut.

    A source plugin serves input_peaks as a single chunk, and a CutPlugin
    depending on it selects entries whose _dtype_name field exceeds
    cut_threshold. The result must have one row per input entry, the right
    number of passing entries, and unchanged start and end times.

    :param input_peaks: structured array used as the data to be cut
    :param cut_threshold: value the _dtype_name field is compared against
    """
    # Just one chunk will do
    chunks = [input_peaks]
    _dtype = input_peaks.dtype

    class ToBeCut(strax.Plugin):
        """Data to be cut with strax.CutPlugin"""
        depends_on = ()
        dtype = _dtype
        provides = 'to_be_cut'
        data_kind = 'to_be_cut'  # match with depends_on below

        def compute(self, chunk_i):
            data = chunks[chunk_i]
            if len(data):
                # Non-empty chunk: take the boundaries from the data itself
                start = int(data[0]['time'])
                end = int(strax.endtime(data[-1]))
            else:
                # Empty chunk: fall back to boundaries based on the index
                start = np.arange(len(chunks))[chunk_i]
                end = np.arange(1, len(chunks) + 1)[chunk_i]
            return self.chunk(data=data, start=start, end=end)

        # Hack to make peak output stop after a few chunks
        def is_ready(self, chunk_i):
            return chunk_i < len(chunks)

        def source_finished(self):
            return True

    class CutSomething(strax.CutPlugin):
        """Minimal working example of CutPlugin"""

        depends_on = ('to_be_cut', )

        def cut_by(self, to_be_cut):
            return to_be_cut[_dtype_name] > cut_threshold

    st = strax.Context(storage=[])
    st.register(ToBeCut)
    st.register(CutSomething)

    target = strax.camel_to_snake(CutSomething.__name__)
    result = st.get_array(run_id='some_run', targets=target)
    correct_answer = np.sum(input_peaks[_dtype_name] > cut_threshold)
    assert len(result) == len(input_peaks), "WTF??"
    assert correct_answer == np.sum(result['cut_something']), (
        "Cut plugin does not give boolean arrays correctly")

    if not len(input_peaks):
        # Nothing to compare time-wise for empty input
        return

    expected_ends = strax.endtime(input_peaks)
    result_ends = strax.endtime(result)
    assert expected_ends.max() == result_ends.max(), (
        "last end time got scrambled")
    assert np.all(input_peaks['time'] == result['time']), (
        "(start) times got scrambled")
    assert np.all(expected_ends == result_ends), (
        "Some end times got scrambled")
コード例 #4
0
    def __init__(self):
        """Fill in default provides, cut_name and cut_description.

        Each attribute is only set when the instance does not already have
        it. The defaults are derived from the snake_case form of the class
        name. The description becomes 'Cut by <name>' unless the name
        contains 'cut_', in which case the name with underscores replaced
        by spaces is used.
        """
        super().__init__()

        snake_name = strax.camel_to_snake(self.__class__.__name__)

        # Both of these default to the plain snake_case class name
        for attr in ('provides', 'cut_name'):
            if not hasattr(self, attr):
                setattr(self, attr, snake_name)

        if hasattr(self, 'cut_description'):
            return

        if 'cut_' in snake_name:
            # The name already reads as a cut: just make it human readable
            self.cut_description = snake_name.replace("_", " ")
        else:
            self.cut_description = 'Cut by ' + snake_name
コード例 #5
0
    def test_bad_configs_raising_errors(self):
        """Test that we get the right errors when we set invalid options"""
        dummy_st = self.st.new_context()

        # (config value, exception expected from plugin.get_tf_model())
        bad_configs = (
            ('some_path_without_tf_protocol', ValueError),
            ('tf://some_path_that_does_not_exists', FileNotFoundError),
        )
        for bad_value, expected_error in bad_configs:
            dummy_st.set_config({self.config_name: bad_value})
            plugin = dummy_st.get_single_plugin(self.run_id, self.target)
            with self.assertRaises(expected_error):
                plugin.get_tf_model()

        # Requesting the registered base class plugin itself must fail
        base_class = straxen.position_reconstruction.PeakPositionsBaseNT
        dummy_st.register(base_class)
        plugin_name = strax.camel_to_snake('PeakPositionsBaseNT')
        with self.assertRaises(NotImplementedError):
            dummy_st.get_single_plugin(self.run_id, plugin_name)