コード例 #1
0
    def test_make_concatenate_and_shuffle_scripts(self):
        """Check the last line of the generated concatenate and shuffle scripts.

        Repeats the config preparation steps, writes the data-split list
        files, asks ``mds`` to create the bash job scripts and then compares
        the final line of each ``*_train`` script with the expected contents
        stored on the test case.
        """
        # main
        # repeat first 4 steps
        self.cfg = read_config(config_file)
        self.ip_group_keys = mds.get_all_ip_group_keys(self.cfg)
        self.cfg, self.n_evts_total = update_cfg(self.cfg)

        self.cfg['n_evts_total'] = self.n_evts_total
        mds.print_input_statistics(self.cfg, self.ip_group_keys)
        for key in self.ip_group_keys:
            mds.add_fpaths_for_data_split_to_cfg(self.cfg, key)
        mds.make_dsplit_list_files(self.cfg)

        # create the bash job scripts and test their content
        mds.make_concatenate_and_shuffle_scripts(self.cfg)

        scripts_and_expected = (
            (concatenate_bash_script_train, self.contents_concatenate_script),
            (shuffle_bash_script_train, self.contents_shuffle_script),
        )
        for script_path, expected_contents in scripts_and_expected:
            # Use unittest assertions rather than a bare ``assert`` so the
            # check is not stripped when Python runs with -O.
            self.assertTrue(os.path.exists(script_path))
            # The ``with`` block closes the file; no explicit close is needed.
            with open(script_path) as f:
                lines = f.readlines()
            # An empty script has no last line; fail with a clear message
            # instead of a NameError on an unbound loop variable.
            self.assertTrue(lines,
                            'script file is empty: ' + str(script_path))
            self.assertIn(lines[-1], expected_contents)
コード例 #2
0
    def test_make_split(self):
        """Check that every line of the written list files is expected.

        Repeats the config preparation steps, lets ``mds`` write the
        data-split list files and then verifies that each line of the
        validation and train list outputs is contained in the matching
        expected collection stored on the test case.
        """
        # main
        # repeat first 3 steps
        self.cfg = read_config(config_file)
        self.ip_group_keys = mds.get_all_ip_group_keys(self.cfg)
        self.cfg, self.n_evts_total = update_cfg(self.cfg)

        self.cfg['n_evts_total'] = self.n_evts_total
        mds.print_input_statistics(self.cfg, self.ip_group_keys)
        for key in self.ip_group_keys:
            mds.add_fpaths_for_data_split_to_cfg(self.cfg, key)
        mds.make_dsplit_list_files(self.cfg)

        # assert the single output lists
        lists_and_expected = (
            (list_output_val, self.file_path_list_val),
            (list_output_train, self.file_path_list),
        )
        for list_path, expected_lines in lists_and_expected:
            # Use unittest assertions rather than a bare ``assert`` so the
            # check is not stripped when Python runs with -O.
            self.assertTrue(os.path.exists(list_path))
            # The ``with`` block closes the file; no explicit close is needed.
            with open(list_path) as f:
                for line in f:
                    self.assertIn(line, expected_lines)
コード例 #3
0
    def test_get_filepath_and_n_events(self):
        """Check the first file path and event count of every input group.

        After the config has been updated, the first entry of each group's
        ``fpaths`` and the group's ``n_evts`` must appear in the expected
        collections stored on the test case.
        """
        # repeat first 2 steps
        self.cfg = read_config(config_file)
        self.ip_group_keys = mds.get_all_ip_group_keys(self.cfg)

        updated_cfg, total_events = update_cfg(self.cfg)
        self.cfg = updated_cfg
        self.n_evts_total = total_events

        for group_name in self.ip_group_keys:
            first_path = self.cfg[group_name]['fpaths'][0]
            self.assertIn(first_path, self.file_path_list)
            event_count = self.cfg[group_name]['n_evts']
            self.assertIn(event_count, self.n_events_list)
コード例 #4
0
def update_cfg(cfg):
    '''Fill each input group of ``cfg`` with file information.

    Changes the working directory to ``test_data_dir``, then stores for
    every input group its h5 file paths, the number of files, the event
    count, the mean events per file and the run ids.

    Returns the (mutated) ``cfg`` together with the event count summed
    over all input groups.
    '''

    # get input groups and compare
    group_names = mds.get_all_ip_group_keys(cfg)
    os.chdir(test_data_dir)
    event_sum = 0
    for name in group_names:
        print('Collecting information from input group ' + name)
        group_cfg = cfg[name]
        group_cfg['fpaths'] = mds.get_h5_filepaths(group_cfg['dir'])
        group_cfg['n_files'] = len(group_cfg['fpaths'])
        stats = mds.get_number_of_evts_and_run_ids(group_cfg['fpaths'],
                                                   dataset_key='y')
        # Unpack first so a wrong-sized result raises before any key is set.
        n_evts, n_evts_mean, run_ids = stats
        group_cfg['n_evts'] = n_evts
        group_cfg['n_evts_per_file_mean'] = n_evts_mean
        group_cfg['run_ids'] = run_ids
        event_sum += group_cfg['n_evts']

    return cfg, event_sum
コード例 #5
0
def update_cfg(cfg):
    """Add file paths and event statistics to every input group of ``cfg``.

    The working directory is switched to ``test_data_dir`` before the
    h5 files of each group are looked up. Returns the (mutated) ``cfg``
    and the number of events summed over all input groups.
    """

    # get input groups and compare
    os_target_dir = test_data_dir
    group_keys = mds.get_all_ip_group_keys(cfg)
    os.chdir(os_target_dir)

    running_total = 0
    for group_key in group_keys:
        print("Collecting information from input group " + group_key)

        h5_paths = mds.get_h5_filepaths(cfg[group_key]["dir"])
        cfg[group_key]["fpaths"] = h5_paths
        cfg[group_key]["n_files"] = len(cfg[group_key]["fpaths"])

        evts, evts_per_file_mean, run_ids = mds.get_number_of_evts_and_run_ids(
            cfg[group_key]["fpaths"], dataset_key="y"
        )
        cfg[group_key]["n_evts"] = evts
        cfg[group_key]["n_evts_per_file_mean"] = evts_per_file_mean
        cfg[group_key]["run_ids"] = run_ids

        running_total += cfg[group_key]["n_evts"]

    return cfg, running_total
コード例 #6
0
 def test_read_keys_off_config(self):
     """Check that the input group keys read from the config are as expected."""
     self.cfg = read_config(config_file)
     # get input groups and compare
     found_keys = mds.get_all_ip_group_keys(self.cfg)
     self.ip_group_keys = found_keys
     expected_keys = self.input_categories_list
     self.assertSequenceEqual(self.ip_group_keys, expected_keys)