예제 #1
0
class OnlyDiTauDataset_wo_mass(Dataset):
    """Di-tau dataset whose tau regression target omits the tau-jet masses.

    On construction, four variable groups (jet / tau / istau / energy) are
    loaded with ``NumpyFlatData().load_file`` for each of the two samples
    ``'Htautau'`` and ``'Zpure_tau'``; the file name is
    ``data_path + f"{sample}_{group}.npy"``. The first ``max_events`` rows of
    each sample are kept, the two samples are concatenated and reordered by a
    fixed permutation drawn with seed 1 (the global NumPy RNG state is saved
    and restored around the draw), and each group is pushed into a
    numpy-backed ``StoreGate`` with phase fractions ``(0.6, 0.2, 0.2)``.
    A ``'label'`` variable is added as well: 1.0 for ``'Htautau'`` rows and
    0.0 for ``'Zpure_tau'`` rows, shuffled with the same permutation.

    The istau group is stored but never read back by this class.

    Each item is a dict::

        {'inputs': [energy, jet_4vec],
         'targets': [tau4vec_target, label],
         'internal_vec': tau4vec_target,      # same object as targets[0]
         'internal_label': label}             # same object as targets[1]

    where ``tau4vec_target`` holds only the (pt, eta, phi) truth tau-jet
    variables, cast to float32.

    Args:
        max_events: number of rows taken from each of the two samples.
        data_path: string prefix prepended verbatim to each file name (plain
            concatenation, so a directory needs a trailing separator).
        phase: initial phase name; ``'train'`` when ``None``.

    NOTE(review): ``Dataset`` is presumably ``torch.utils.data.Dataset`` --
    confirm against the file's imports.
    """

    def __init__(self, max_events, data_path, phase=None):
        # multiml is imported here, at construction time, not at module level.
        from multiml.storegate import StoreGate
        from multiml.data.numpy import NumpyFlatData

        # Reproducible shuffle (seed 1) that leaves the caller's global
        # NumPy random stream exactly as it was.
        saved_rng_state = np.random.get_state()
        np.random.seed(1)
        shuffle_idx = np.random.permutation(2 * max_events)
        np.random.set_state(saved_rng_state)

        self.storegate = StoreGate(backend='numpy', data_id='')
        self.jet_vals = [
            '1stRecoJetPt', '1stRecoJetEta', '1stRecoJetPhi', '1stRecoJetMass',
            '2ndRecoJetPt', '2ndRecoJetEta', '2ndRecoJetPhi', '2ndRecoJetMass'
        ]
        self.tau_vals = [
            '1stTruthTauJetPt', '1stTruthTauJetEta', '1stTruthTauJetPhi',
            '1stTruthTauJetMass', '2ndTruthTauJetPt', '2ndTruthTauJetEta',
            '2ndTruthTauJetPhi', '2ndTruthTauJetMass'
        ]
        # Same as tau_vals minus the two mass columns; used as the target.
        self.tau_vals_wo_mass = [
            '1stTruthTauJetPt',
            '1stTruthTauJetEta',
            '1stTruthTauJetPhi',
            '2ndTruthTauJetPt',
            '2ndTruthTauJetEta',
            '2ndTruthTauJetPhi',
        ]
        self.istau_vals = ['tauFlag1stJet', 'tauFlag2ndJet']
        self.label_vals = ['label']
        self.energy_vals = ['1stRecoJetEnergyMap', '2ndRecoJetEnergyMap']

        split = (0.6, 0.2, 0.2)
        file_groups = (
            ("jet.npy", self.jet_vals),
            ("tau.npy", self.tau_vals),
            ("istau.npy", self.istau_vals),
            ("energy.npy", self.energy_vals),
        )
        for suffix, names in file_groups:
            # 'Htautau' rows first, then 'Zpure_tau' rows -- this order must
            # match the ones/zeros layout of the labels built below.
            merged = np.concatenate([
                NumpyFlatData().load_file(data_path +
                                          f"{sample}_{suffix}")[:max_events]
                for sample in ('Htautau', 'Zpure_tau')
            ])
            self.storegate.update_data(data_id='',
                                       data=merged[shuffle_idx],
                                       var_names=names,
                                       phase=split)

        labels = np.concatenate([np.ones(max_events), np.zeros(max_events)])
        self.storegate.update_data(data_id='',
                                   data=labels[shuffle_idx],
                                   var_names='label',
                                   phase=split)
        self.storegate.compile()

        self.phase = 'train' if phase is None else phase
        self._swich_phase()

    def _swich_phase(self):
        """Re-read the cached per-phase arrays for ``self.phase``.

        (The misspelled name is kept as-is; other code may call it.)
        """
        def fetch(names):
            return self.storegate.get_data(names, self.phase, '')

        # Axis permutation (a, b, c, d, e) -> (a, b, e, c, d): the last axis
        # of the energy maps moves to position 2. Presumably channel-last to
        # channel-first per jet -- TODO confirm against the data layout.
        self.energy = np.transpose(fetch(self.energy_vals), (0, 1, 4, 2, 3))
        self.jet_4vec = fetch(self.jet_vals)
        self.tau4vec_target = fetch(self.tau_vals_wo_mass).astype(np.float32)
        self.label = fetch(self.label_vals).astype(np.float32)

    def _set_phase(self, name):
        """Record the phase name and refresh the cached arrays."""
        self.phase = name
        self._swich_phase()

    def train(self):
        """Switch to the ``'train'`` phase."""
        self._set_phase('train')

    def valid(self):
        """Switch to the ``'valid'`` phase."""
        self._set_phase('valid')

    def test(self):
        """Switch to the ``'test'`` phase."""
        self._set_phase('test')

    def __len__(self):
        """Number of events in the current phase, per StoreGate metadata."""
        phase_sizes = self.storegate.get_metadata()['sizes']
        return phase_sizes[self.phase]

    def __getitem__(self, idx):
        """Return the inputs/targets dict for row ``idx`` of the current phase."""
        energy = self.energy[idx]
        jet_4vec = self.jet_4vec[idx]
        target_vec = self.tau4vec_target[idx]
        label = self.label[idx]
        return dict(
            inputs=[energy, jet_4vec],
            targets=[target_vec, label],
            internal_vec=target_vec,
            internal_label=label,
        )
예제 #2
0
def test_storegate_zarr():
    """Exercise StoreGate with the 'hybrid' backend, then the 'numpy' backend.

    Covers:

    * data-id handling;
    * adding and updating data with and without ``bind_vars``;
    * moving data between storage and memory;
    * compile, metadata and info;
    * phase-wise sizes and shapes;
    * deleting variables;
    * dtype, one-hot and argmax conversions.

    It finishes by checking that two ``shuffle=True`` ``add_data`` calls on
    identical input yield the same first training row.

    Relies on module-level fixtures defined elsewhere in this test module
    (``data_id``, ``var_names01``, ``var_names23``, ``data0``,
    ``data01_bind``, ``data23``, ``data23_bind``) and on ``StoreGate``.
    """
    storegate = StoreGate(backend='hybrid',
                          backend_args={'mode': 'w'},
                          data_id=data_id)

    assert storegate.data_id == data_id
    storegate.set_data_id(data_id)
    assert storegate.data_id == data_id

    phase = (0.8, 0.1, 0.1)

    # add new variables (bind_vars=True)
    storegate.add_data(var_names=var_names01, data=data01_bind, phase=phase)

    assert storegate.get_data_ids() == [data_id]

    # change hybrid mode
    storegate.set_mode('numpy')

    # add new variables (bind_vars=False)
    storegate.add_data(var_names=var_names23,
                       data=data23,
                       phase=phase,
                       bind_vars=False)
    storegate.to_storage(var_names=var_names23, phase='all')
    storegate.to_memory(var_names=var_names23, phase='train')

    # update existing variables and data (bind_vars=True)
    storegate.update_data(var_names=var_names23, data=data23_bind, phase=phase)

    # update existing variables and data (bind_vars=False)
    storegate.update_data(var_names=var_names23,
                          data=data23,
                          phase=phase,
                          bind_vars=False)

    # compile for multiai
    storegate.compile()
    storegate.get_metadata()
    storegate.show_info()

    # update data by auto mode
    storegate.update_data(var_names=var_names23,
                          data=data23,
                          phase='all',
                          bind_vars=False)
    storegate['all'][var_names23][:] = data23_bind
    storegate.compile()

    # get data (the results are discarded; these calls only need to succeed)
    storegate.get_data(var_names=var_names23, phase='train', index=0)
    storegate.get_data(var_names=var_names23, phase='all')
    storegate['all'][var_names23][:]

    # tests: expected per-phase event counts from the split fractions
    total_events = len(data0)
    train_events = total_events * phase[0]
    valid_events = total_events * phase[1]
    test_events = total_events * phase[2]

    assert len(storegate['train']) == train_events
    assert len(storegate['valid']) == valid_events
    assert len(storegate['test']) == test_events

    assert var_names23[0] in storegate['train']
    assert var_names23[1] in storegate['train']

    storegate_train = storegate['train']

    assert storegate_train[var_names23][:].shape == (train_events,
                                                     len(var_names23))
    assert storegate_train[var_names23][0].shape == (len(var_names23), )
    assert storegate_train[var_names23[0]][:].shape == (train_events, )
    assert storegate_train[var_names23[0]][0] == 20

    # delete data
    storegate.delete_data(var_names='var2', phase='train')
    del storegate['train'][['var0', 'var1']]
    assert storegate.get_var_names(phase='train') == ['var3']

    # convert dtype. `np.int` was only a deprecated alias of the builtin
    # `int` and was removed in NumPy 1.24, so use `int` directly.
    storegate.astype(var_names='var3', dtype=int)
    # one-hot encode then argmax back; presumably this round-trips to the
    # original integer value, which the assert below relies on.
    storegate.onehot(var_names='var3', num_classes=40)
    storegate.argmax(var_names='var3', axis=1)
    assert storegate['train']['var3'][0] == 30

    # test shuffle with numpy mode: both calls add the same data with
    # shuffle=True, so their first training rows are expected to match.
    storegate = StoreGate(backend='numpy', data_id=data_id)
    storegate.add_data(var_names=var_names01,
                       data=data01_bind,
                       phase=phase,
                       shuffle=True)
    storegate.add_data(var_names=var_names23,
                       data=data01_bind,
                       phase=phase,
                       shuffle=True)
    storegate.compile()

    storegate_train = storegate['train']
    assert storegate_train[var_names01][0][0] == storegate_train[var_names23][
        0][0]