예제 #1
0
    def setup(self):
        """Build the Keras network and the top-PMT mask from resources.

        The architecture comes from ``nn_architecture`` (JSON) and the weights
        from ``nn_weights`` (binary), both via ``get_resource``. The
        architecture JSON also carries a ``badPMTList`` entry, which is
        stripped before the JSON is handed to keras and used to build the mask.

        Sets:
            self.pmt_mask: boolean array of length ``n_top_pmts``; False for
                channels listed in ``badPMTList``.
            self.nn: the Keras model with its weights loaded.
        """
        import tensorflow as tf
        keras = tf.keras

        nn_conf = get_resource(self.config['nn_architecture'], fmt='json')
        # badPMTList was added to the keras json by hand. Strip it before
        # giving the json to keras so future keras versions don't crash on
        # the unknown key.
        # Do NOT try `del nn_conf['badPMTList']`! See the get_resource
        # docstring for the gruesome details.
        bad_pmts = nn_conf['badPMTList']
        nn = keras.models.model_from_json(
            json.dumps({k: v
                        for k, v in nn_conf.items() if k != 'badPMTList'}))
        # np.isin replaces the deprecated np.in1d (identical for 1-D input).
        self.pmt_mask = ~np.isin(np.arange(self.config['n_top_pmts']),
                                 bad_pmts)

        # Keras needs a file path to load its weights. The file must be
        # closed before keras opens it, because Windows cannot reopen a file
        # that is still open. The file is removed in `finally` so it does
        # not leak when loading fails.
        weights = get_resource(self.config['nn_weights'], fmt='binary')
        fd, fname = tempfile.mkstemp()
        try:
            with os.fdopen(fd, 'wb') as f:
                f.write(weights)
            nn.load_weights(fname)
        finally:
            os.remove(fname)
        self.nn = nn
예제 #2
0
    def setup(self):
        """Load the S1/S2 correction maps and the electron lifetime.

        ``s1_xyz_correction_map`` is either a CMT tuple ``(cmt, conf, is_nt)``
        or a plain resource name. For the tuple form, ``conf[0]`` gets
        ``_<default_reconstruction_algorithm>`` appended before the CMT
        lookup. ``s2_xy_correction_map`` is always resolved through CMT.

        Sets ``self.s1_map`` and ``self.s2_map`` (InterpolatingMap instances)
        and ``self.elife``.
        """
        run_id = self.run_id
        s1_option = self.config['s1_xyz_correction_map']

        if not isinstance(s1_option, tuple):
            self.s1_map = InterpolatingMap(get_resource(s1_option))
        else:
            cmt, cmt_conf, is_nt = s1_option
            algorithm = self.config["default_reconstruction_algorithm"]
            # The CMT config name is specialised per reconstruction algorithm.
            suffixed_conf = (f'{cmt_conf[0]}_{algorithm}', cmt_conf[1])
            s1_source = get_config_from_cmt(run_id,
                                            (cmt, suffixed_conf, is_nt))
            self.s1_map = InterpolatingMap(get_resource(s1_source))

        s2_source = get_config_from_cmt(run_id,
                                        self.config['s2_xy_correction_map'])
        self.s2_map = InterpolatingMap(get_resource(s2_source))

        self.elife = get_correction_from_cmt(run_id,
                                             self.config['elife_conf'])
        if isinstance(self.elife, str):
            # Legacy 1T support: a string result is resolved via get_elife.
            self.elife = get_elife(self.run_id, self.elife)
예제 #3
0
    def setup(self):
        """Load drift constants and the FDC map for this run.

        Stores ``electron_drift_velocity`` and ``electron_drift_time_gate``
        from CMT. ``fdc_map`` may be a plain resource name (loaded as is) or
        a CMT option. A CMT option is fetched with a
        ``('suffix', <default_reconstruction_algorithm>, ...)`` key, and its
        third coordinate is scaled by ``-electron_drift_velocity``.

        Raises:
            NotImplementedError: if ``fdc_map`` is neither a string nor a CMT
                option.
        """
        run_id = self.run_id
        self.electron_drift_velocity = get_correction_from_cmt(
            run_id, self.config['electron_drift_velocity'])
        self.electron_drift_time_gate = get_correction_from_cmt(
            run_id, self.config['electron_drift_time_gate'])

        fdc_option = self.config['fdc_map']

        if isinstance(fdc_option, str):
            # Plain resource name: load it directly, without rescaling.
            self.map = InterpolatingMap(
                get_resource(fdc_option, fmt='binary'))
            return

        if not is_cmt_option(fdc_option):
            raise NotImplementedError('FDC map format not understood.')

        cmt_key = ('suffix',
                   self.config['default_reconstruction_algorithm'],
                   *fdc_option)
        self.map = InterpolatingMap(
            get_cmt_resource(run_id, cmt_key, fmt='binary'))
        # NOTE(review): presumably converts the map's z axis to drift-time
        # units via the drift velocity — confirm against InterpolatingMap.
        self.map.scale_coordinates([1., 1., -self.electron_drift_velocity])
예제 #4
0
    def setup(self):
        """Prepare the simulator configuration, then build the instructions.

        Steps:
        - Merge the JSON ``fax_config`` resource into ``self.config``.
        - Derive per-channel ``gains`` from the run's ``to_pe`` values.
        - Optionally seed numpy's global RNG.
        - Apply ``fax_config_override``.
        - Add derived channel settings.
        - Call ``get_instructions``, ``check_instructions`` and ``_setup``.
        """
        c = self.config
        c.update(get_resource(c['fax_config'], fmt='json'))

        # Update gains to the nT defaults
        self.to_pe = get_to_pe(self.run_id, c['gain_model'],
                               c['channel_map']['tpc'][1] + 1)
        # NOTE(review): the constant presumably encodes the digitizer/amplifier
        # chain (sample time, ADC range, electron charge, amplification,
        # impedance) — confirm. Channels with to_pe == 0 divide by zero; the
        # warning is silenced and their gains are zeroed right below.
        with np.errstate(divide='ignore'):
            c['gains'] = 1 / self.to_pe * (1e-8 * 2.25 / 2**14) / (
                1.6e-19 * 10 * 50)
        c['gains'][self.to_pe == 0] = 0

        # `is not False` rather than `!= False`: a seed of 0 compares equal
        # to False and would otherwise be silently ignored.
        if c['seed'] is not False:
            np.random.seed(c['seed'])

        overrides = c['fax_config_override']
        if overrides is not None:
            c.update(overrides)

        # We hash the config to load resources. The channel map is immutable
        # and cannot be hashed, so store a plain dict copy instead.
        c['channel_map'] = dict(c['channel_map'])
        c['channel_map']['sum_signal'] = 800
        c['channels_bottom'] = np.arange(c['n_top_pmts'], c['n_tpc_pmts'])

        self.get_instructions()
        self.check_instructions()
        self._setup()
예제 #5
0
파일: plugins.py 프로젝트: XeBoris/straxen
    def setup(self):
        """Build the Keras network and the top-PMT mask.

        Loads the architecture (``nn_architecture``) and the weights
        (``nn_weights``) through ``get_resource``. The mask marks top PMTs
        with ``to_pe > 0``. Also prepares the graph handle needed to call
        the model from threads.
        """
        import os
        import tempfile

        import keras
        import tensorflow as tf

        # NOTE(review): `to_pe` is not defined in this method; presumably a
        # module-level array of per-channel conversion factors — confirm.
        self.pmt_mask = to_pe[:self.n_top_pmts] > 0

        nn = keras.models.model_from_json(
            get_resource(self.config['nn_architecture']))

        weights = get_resource(self.config['nn_weights'], binary=True)
        # Use a unique temp file, not a fixed '_temp.h5' in the working
        # directory: the fixed name was never cleaned up and could be
        # overwritten by a concurrent process. Keep the .h5 suffix so keras
        # still treats the path as HDF5.
        fd, temp_f = tempfile.mkstemp(suffix='.h5')
        try:
            with os.fdopen(fd, mode='wb') as f:
                f.write(weights)
            nn.load_weights(temp_f)
        finally:
            os.remove(temp_f)
        self.nn = nn

        # Workaround for using keras/tensorflow in a threaded environment. See:
        # https://github.com/keras-team/keras/issues/5640#issuecomment-345613052
        self.nn._make_predict_function()
        self.graph = tf.get_default_graph()
예제 #6
0
    def setup(self):
        """Load ``to_pe``, build the Keras network and the top-PMT mask.

        The architecture JSON (``nn_architecture``) also carries a
        ``badPMTList``; channels in that list are False in
        ``self.pmt_mask``. Weights come from ``nn_weights``. Also prepares
        the graph handle needed to call the model from threads.
        """
        import os
        import tempfile

        import keras
        import tensorflow as tf

        self.to_pe = get_to_pe(self.run_id, self.config['to_pe_file'])

        nn_json = get_resource(self.config['nn_architecture'])
        nn = keras.models.model_from_json(nn_json)

        bad_pmts = json.loads(nn_json)['badPMTList']
        # np.isin replaces the deprecated np.in1d (identical for 1-D input).
        self.pmt_mask = ~np.isin(np.arange(self.n_top_pmts), bad_pmts)

        # Keras loads weights from a path, so the file must be fully written
        # and closed before keras opens it. Writing into a still-open
        # NamedTemporaryFile without flushing can leave it truncated, and
        # Windows cannot reopen it at all. Always remove the file afterwards.
        weights = get_resource(self.config['nn_weights'], fmt='binary')
        fd, fname = tempfile.mkstemp()
        try:
            with os.fdopen(fd, 'wb') as f:
                f.write(weights)
            nn.load_weights(fname)
        finally:
            os.remove(fname)
        self.nn = nn

        # Workaround for using keras/tensorflow in a threaded environment. See:
        # https://github.com/keras-team/keras/issues/5640#issuecomment-345613052
        self.nn._make_predict_function()
        self.graph = tf.get_default_graph()
예제 #7
0
    def setup(self):
        """Load the FDC map from either a CMT tuple or a resource name.

        A tuple ``(cmt, conf, is_nt)`` is resolved through CMT, with
        ``conf[0]`` suffixed by ``_<default_reconstruction_algorithm>``. The
        resulting map's third coordinate is then scaled by
        ``-electron_drift_velocity``. A string is loaded directly.

        Raises:
            NotImplementedError: if ``fdc_map`` is neither a tuple nor a str.
        """
        fdc_option = self.config['fdc_map']

        if not isinstance(fdc_option, (tuple, str)):
            raise NotImplementedError('FDC map format not understood.')

        if isinstance(fdc_option, str):
            self.map = InterpolatingMap(
                get_resource(fdc_option, fmt='binary'))
            return

        cmt, cmt_conf, is_nt = fdc_option
        algorithm = self.config["default_reconstruction_algorithm"]
        map_algo = (cmt, (f'{cmt_conf[0]}_{algorithm}', cmt_conf[1]), is_nt)
        source = get_config_from_cmt(self.run_id, map_algo)
        self.map = InterpolatingMap(get_resource(source, fmt='binary'))
        drift_velocity = self.config['electron_drift_velocity']
        self.map.scale_coordinates([1., 1., -drift_velocity])
예제 #8
0
    def setup(self):
        """Build the Keras network and the top-PMT mask from resources.

        Uses ``tf.keras`` when TensorFlow is at least a 2.0 pre-release, and
        otherwise the standalone ``keras`` package. The architecture JSON
        (``nn_architecture``) carries a ``badPMTList`` entry, which is
        stripped before keras sees the JSON and used to build
        ``self.pmt_mask``. Weights come from ``nn_weights``.

        Sets ``self.has_tf2``, ``self.pmt_mask`` and ``self.nn``. On the
        pre-TF2 path it also sets ``self.graph``.
        """
        import tensorflow as tf
        self.has_tf2 = parse_version(tf.__version__) > parse_version('2.0.a')
        if self.has_tf2:
            keras = tf.keras
        else:
            import keras

        nn_conf = get_resource(self.config['nn_architecture'], fmt='json')
        # badPMTList was added to the keras json by hand. Strip it before
        # giving the json to keras so future keras versions don't crash on
        # the unknown key.
        # Do NOT try `del nn_conf['badPMTList']`! See the get_resource
        # docstring for the gruesome details.
        bad_pmts = nn_conf['badPMTList']
        nn = keras.models.model_from_json(
            json.dumps({k: v
                        for k, v in nn_conf.items() if k != 'badPMTList'}))
        # np.isin replaces the deprecated np.in1d (identical for 1-D input).
        self.pmt_mask = ~np.isin(np.arange(self.config['n_top_pmts']),
                                 bad_pmts)

        # Keras needs a file path to load its weights. The file must be
        # closed before keras opens it, because Windows cannot reopen a file
        # that is still open. The file is removed in `finally` so it does
        # not leak when loading fails.
        weights = get_resource(self.config['nn_weights'], fmt='binary')
        fd, fname = tempfile.mkstemp()
        try:
            with os.fdopen(fd, 'wb') as f:
                f.write(weights)
            nn.load_weights(fname)
        finally:
            os.remove(fname)
        self.nn = nn

        if not self.has_tf2:
            # Workaround for using keras/tensorflow in a threaded environment.
            # See: https://github.com/keras-team/keras/issues/
            # 5640#issuecomment-345613052
            self.nn._make_predict_function()
            self.graph = tf.get_default_graph()
예제 #9
0
    def setup(self):
        """Prepare the simulator configuration and the instruction array.

        Steps:
        - Merge the JSON ``fax_config`` resource into ``self.config``.
        - Derive per-channel ``gains`` from the run's ``to_pe`` values.
        - Optionally seed numpy's global RNG.
        - Apply ``fax_config_override``.
        - Build ``self.instructions`` from an optical file, a csv file or
          random generation.
        - Drop S2 instructions below the cathode.

        Raises:
            AssertionError: for a non-optical ``.root`` input, or when an
                instruction lies outside the TPC or has non-positive ``amp``.
                These checks are plain asserts and are skipped under ``-O``.
        """
        c = self.config
        c.update(get_resource(c['fax_config'], fmt='json'))

        # Update gains to the nT defaults
        self.to_pe = get_to_pe(self.run_id, c['gain_model'],
                               len(c['channels_in_detector']['tpc']))
        # Channels with to_pe == 0 divide by zero; the warning is silenced
        # and their gains are zeroed right below.
        with np.errstate(divide='ignore'):
            c['gains'] = 1 / self.to_pe * (1e-8 * 2.25 / 2**14) / (
                1.6e-19 * 10 * 50)
        c['gains'][self.to_pe == 0] = 0

        # `is not False` rather than `!= False`: a seed of 0 compares equal
        # to False and would otherwise be silently ignored.
        if c['seed'] is not False:
            np.random.seed(c['seed'])

        overrides = c['fax_config_override']
        if overrides is not None:
            c.update(overrides)

        if c['optical']:
            self.instructions, self.channels, self.timings = read_optical(
                c['fax_file'])
            c['nevents'] = len(self.instructions['event_number'])

        elif c['fax_file']:
            assert not c['fax_file'].endswith('.root'), \
                'None optical g4 input is deprecated use EPIX instead'
            self.instructions = instruction_from_csv(c['fax_file'])
            c['nevents'] = np.max(self.instructions['event_number'])

        else:
            self.instructions = rand_instructions(c)

        # Let below-cathode S1 instructions pass, but remove S2 instructions
        # (type == 2) located below the cathode.
        m = (self.instructions['z'] <
             -c['tpc_length']) & (self.instructions['type'] == 2)
        self.instructions = self.instructions[~m]

        inst = self.instructions
        assert np.all(inst['x']**2 + inst['y']**2 < c['tpc_radius']**2), \
            "Interation is outside the TPC"
        assert np.all(inst['z'] < 0.25), \
            "Interation is outside the TPC"
        assert np.all(inst['amp'] > 0), \
            "Interaction has zero size"
예제 #10
0
    def setup(self):
        """Load the simulation config and build ``self.instructions``.

        Instructions are read from ``fax_file`` with ``read_g4`` for
        ``.root`` files and ``instruction_from_csv`` otherwise. In both
        cases ``nevents`` becomes the maximum ``event_number``. Without a
        ``fax_file``, random instructions are generated.

        Raises:
            AssertionError: if an instruction lies outside the fixed TPC
                bounds checked below or has non-positive ``amp``.
        """
        c = self.config
        c.update(get_resource(c['fax_config'], fmt='json'))

        fax_file = c['fax_file']
        if not fax_file:
            self.instructions = rand_instructions(c)
        else:
            reader = read_g4 if fax_file.endswith('.root') \
                else instruction_from_csv
            self.instructions = reader(fax_file)
            c['nevents'] = np.max(self.instructions['event_number'])

        inst = self.instructions
        r_squared = inst['x']**2 + inst['y']**2
        assert np.all(r_squared < 2500), "Interation is outside the TPC"
        assert np.all(inst['z'] < 0) and np.all(inst['z'] > -100), \
            "Interation is outside the TPC"
        assert np.all(inst['amp'] > 0), "Interaction has zero size"
예제 #11
0
    def __init__(self, config=None):
        """Set up the configuration, instructions and output plugins.

        Args:
            config: optional mapping of overrides. It is applied after
                ``default_config`` and the JSON ``fax_config`` resource.
                ``None`` (the default) means no overrides.
        """
        # Work on a shallow copy: updating ``default_config`` in place would
        # make every later instance inherit this instance's settings.
        self.config = dict(default_config)
        self.config.update(get_resource(self.config['fax_config'], fmt='json'))
        if config is not None:
            self.config.update(config)

        fax_file = self.config['fax_file']
        if fax_file:
            if fax_file.endswith('.root'):
                self.instructions = read_g4(fax_file)
            else:
                self.instructions = instruction_from_csv(fax_file)
            self.config['nevents'] = np.max(
                self.instructions['event_number'])

        else:
            self.instructions = rand_instructions(self.config)

        self.pax_event = PaxEvents(self.config)
        self.transfer_plugin = self.WriteZippedEncoder(self.config)
        self.output_plugin = self.WriteZipped(self.config)
예제 #12
0
 def setup(self):
     """Load the S1/S2 relative LCE maps and the run's electron lifetime."""
     def load_map(option):
         # Resolve a config option to a resource and wrap it as a map.
         return InterpolatingMap(get_resource(self.config[option]))

     self.s1_map = load_map('s1_relative_lce_map')
     self.s2_map = load_map('s2_relative_lce_map')
     self.elife = get_elife(self.run_id, self.config['elife_file'])
예제 #13
0
 def setup(self):
     """Load the FDC map named by the ``fdc_map`` option (binary resource)."""
     fdc_blob = get_resource(self.config['fdc_map'], fmt='binary')
     self.map = InterpolatingMap(fdc_blob)
예제 #14
0
파일: plugins.py 프로젝트: XeBoris/straxen
 def setup(self):
     """Load the S1 and S2 relative LCE maps from their config options."""
     def load_map(option):
         # Resolve a config option to a resource and wrap it as a map.
         return InterpolatingMap(get_resource(self.config[option]))

     self.s1_map = load_map('s1_relative_lce_map')
     self.s2_map = load_map('s2_relative_lce_map')
예제 #15
0
파일: plugins.py 프로젝트: XeBoris/straxen
 def setup(self):
     """Load the FDC map named by the ``fdc_map`` option (binary resource)."""
     fdc_blob = get_resource(self.config['fdc_map'], binary=True)
     self.map = InterpolatingMap(fdc_blob)