Example #1
File: rootcache.py Project: claria/Artus
    def _get_cached(self, *args, **kwargs):
        cache_file = self._determine_cache_file(*args, **kwargs)
        root_tree = None
        root_object = None
        cache_found = False
        if kwargs.get("use_cache", True) and os.path.exists(cache_file):
            with tfilecontextmanager.TFileContextManager(cache_file,
                                                         "READ") as root_file:
                root_object = root_file.Get(self.cache_name)
                # ROOT's Get can return a null object proxy instead of None, hence both checks
                if (root_object is not None) and (root_object != None):
                    root_object.SetDirectory(0)
                    cache_found = True
                    log.debug(
                        "Took cached object from \"{root_file}/{path_to_object}\"."
                        .format(root_file=cache_file,
                                path_to_object=self.cache_name))
                else:
                    root_object = None

        if root_object is None:
            root_tree, root_object = self._function_to_cache(*args, **kwargs)
            if (not cache_found) and (root_object is not None) and (
                    root_object != None):
                with tfilecontextmanager.TFileContextManager(
                        cache_file, "RECREATE") as root_file:
                    root_file.cd()
                    root_object.Write(self.cache_name,
                                      ROOT.TObject.kWriteDelete)
                    log.debug(
                        "Created cache in \"{root_file}/{path_to_object}\".".
                        format(root_file=cache_file,
                               path_to_object=self.cache_name))
                    root_object.SetDirectory(0)

        return root_tree, root_object
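
All of the examples on this page use tfilecontextmanager.TFileContextManager from Artus to make sure the ROOT file is closed again after the with block. Its implementation is not part of this listing; a minimal sketch of such a context manager, assuming it does nothing more than wrap ROOT.TFile.Open and TFile.Close, could look like this:

import ROOT

class TFileContextManager(object):
    # Minimal sketch only: assumed behaviour, not the actual Artus implementation.
    def __init__(self, filename, option="READ"):
        self._filename = filename
        self._option = option
        self._tfile = None

    def __enter__(self):
        # open the ROOT file and hand the TFile object to the with block
        self._tfile = ROOT.TFile.Open(self._filename, self._option)
        return self._tfile

    def __exit__(self, exc_type, exc_value, traceback):
        # close the file even if the with block raised an exception
        if self._tfile:
            self._tfile.Close()
        return False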
Example #2
    def all_runtimes(root_filename, plot_config_template=None):
        if plot_config_template is None:
            plot_config_template = {}

        plot_configs = []
        with tfilecontextmanager.TFileContextManager(root_filename,
                                                     "READ") as root_file:
            elements = roottools.RootTools.walk_root_directory(root_file)
            for key, path in elements:
                if path.endswith("runTime") and (
                        key.GetClassName().startswith("TTree")
                        or key.GetClassName().startswith("TNtuple")):
                    tree = root_file.Get(path)
                    branches = [
                        branch.GetName()
                        for branch in tree.GetListOfBranches()
                    ]
                    for branch in branches:
                        plot_config = copy.deepcopy(plot_config_template)
                        plot_config["files"] = [root_filename]
                        plot_config["folders"] = [path]
                        plot_config["x_expressions"] = [branch]
                        plot_config["weights"] = ["({0}>=0.0)".format(branch)]
                        plot_config["title"] = branch
                        plot_config["x_label"] = "Runtime per Event / #mus"
                        plot_config["output_dir"] = os.path.join(
                            plot_config.get("output_dir", ""),
                            os.path.dirname(path))
                        plot_config["filename"] = branch
                        plot_configs.append(plot_config)
        return plot_configs
Example #3
File: csv2root.py Project: whahmad/Artus
def csv2root(args):
	csv_filename = args[0]
	variable_list = args[1]
	root_filename = os.path.splitext(csv_filename)[0]+".root"
	with tfilecontextmanager.TFileContextManager(root_filename, "RECREATE") as root_file:
		tree = ROOT.TTree("tree", csv_filename)
		tree.ReadFile(csv_filename, variable_list)
		tree.Write(tree.GetName())
	log.info("Converted {csv} to {root}:tree.".format(csv=csv_filename, root=root_filename))
Example #4
def treemerge(input_file_names, input_tree_names,
              output_file_name, output_tree_name,
              match_input_tree_names = False):
	
	input_tree = ROOT.TChain()
	for input_file_name, input_tree_name in zip(input_file_names, input_tree_names):
		with tfilecontextmanager.TFileContextManager(input_file_name, "READ") as root_file:
			tree_names = [path for key, path in roottools.RootTools.walk_root_directory(root_file) if key.GetClassName().startswith("TTree")]
		selected_tree_names = [tree_name for tree_name in tree_names if re.search(input_tree_name, tree_name)]
		for name in selected_tree_names:
			input_tree.Add(os.path.join(input_file_name, name))
	output_file = ROOT.TFile(output_file_name, "RECREATE")
	print input_tree  # do not remove; the merge does not work without this statement
	output_tree = input_tree.CloneTree()
	output_tree.SetName(output_tree_name)
	output_file.Write()
	output_file.Close()
	return os.path.join(output_file_name, output_tree_name)
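
A hypothetical call of treemerge could look like the sketch below (all file names and the tree pattern are illustrations only); note that the match_input_tree_names argument is accepted but not used in the snippet shown above.

# hypothetical usage: merge all trees whose path matches "ntuple" from two files
merged_path = treemerge(
	["file1.root", "file2.root"],  # illustrative input files
	["ntuple", "ntuple"],          # one tree name pattern per input file
	"merged.root",
	"ntuple")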
Example #5
	def all_histograms(root_filename, plot_config_template=None):
		if plot_config_template is None:
			plot_config_template = {}
		
		plot_configs = []
		with tfilecontextmanager.TFileContextManager(root_filename, "READ") as root_file:
			elements = roottools.RootTools.walk_root_directory(root_file)
			for key, path in elements:
				if key.GetClassName().startswith("TH") or key.GetClassName().startswith("TProfile"):
					plot_config = copy.deepcopy(plot_config_template)
					plot_config["files"] = [root_filename]
					plot_config["folders"] = [os.path.dirname(path)]
					plot_config["x_expressions"] = [os.path.basename(path)]
					plot_config["title"] = path
					plot_config["x_label"] = ""
					plot_config["output_dir"] = os.path.join(os.path.splitext(root_filename)[0], os.path.dirname(path))
					plot_config["filename"] = os.path.basename(path)
					plot_configs.append(plot_config)
		return plot_configs
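
The dictionaries collected in plot_configs are plot configurations for the Artus plotting tools. Assuming the higgsplot.HiggsPlotter interface that also appears in Example #8 below, a hypothetical way to plot every histogram of a file would be:

# hypothetical usage; "input.root" is an illustrative file name and it is
# assumed that all_histograms is accessible as a plain helper function
plot_configs = all_histograms("input.root")
higgsplot.HiggsPlotter(list_of_config_dicts=plot_configs,
                       list_of_args_strings=[""])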
Example #6
def main():

    parser = argparse.ArgumentParser(
        description="Add friend to tree in file with constant values.",
        parents=[logger.loggingParser])

    parser.add_argument(
        "files",
        nargs="+",
        help="Files containing the tree. The files will be updated.")
    parser.add_argument("-t",
                        "--tree",
                        required=True,
                        help="Path to the tree object in the file.")
    parser.add_argument(
        "--values",
        nargs="+",
        required=True,
        help=
        "Values to add to the friend tree. Each value gets a separate branch.")
    parser.add_argument("-b",
                        "--branches",
                        nargs="+",
                        default=[None],
                        help="Branch names.")

    args = parser.parse_args()
    logger.initLogger(args)

    if len(args.branches) < len(args.values):
        args.branches = (args.branches +
                         (len(args.values) * [None]))[:len(args.values)]
    args.branches = [
        "annotation%d" % i if b is None else b
        for i, b in enumerate(args.branches)
    ]

    types = {
        bool: "O",
        int: "I",
        float: "D",
        str: "C",
    }
    value_types = []
    for index, value in enumerate(args.values):
        try:
            value = int(value)
            value_types.append(int)
        except ValueError:
            try:
                value = float(value)
                value_types.append(float)
            except ValueError:
                # TODO: support bool values and find a way to branch strings
                raise Exception(
                    "Only int and float values are supported so far.")
        args.values[index] = value

    log.info("New branches:")
    for index, (branch, value_type, value) in enumerate(
            zip(args.branches, value_types, args.values)):
        log.info("\t%s/%s = %s" % (branch, types[value_type], str(value)))

    for file_name in progressiterator.ProgressIterator(
            args.files, description="Processing files"):
        with tfilecontextmanager.TFileContextManager(file_name,
                                                     "UPDATE") as root_file:
            tree = root_file.Get(args.tree)

            dir_name = os.path.dirname(args.tree)
            if not dir_name == "":
                root_file.Get(dir_name)

            elements = zip(
                *roottools.RootTools.walk_root_directory(root_file))[-1]
            friend_tree_name = None
            n_trials = 0
            while friend_tree_name is None:
                tmp_friend_tree_name = (tree.GetName() + "_friend_" +
                                        str(n_trials)).rstrip("_0")
                if tmp_friend_tree_name not in elements:
                    friend_tree_name = tmp_friend_tree_name
                n_trials += 1
            friend_tree = ROOT.TTree(friend_tree_name,
                                     tree.GetTitle() + " (friend)")

            values = []
            for branch, value_type, value in zip(args.branches, value_types,
                                                 args.values):
                values.append(numpy.zeros(1, dtype=value_type))
                values[-1][0] = value
                friend_tree.Branch(branch, values[-1],
                                   "%s/%s" % (branch, types[value_type]))

            for entry in xrange(tree.GetEntries()):
                friend_tree.Fill()

            friend_tree.AddFriend(tree, tree.GetName())
            tree.AddFriend(friend_tree, friend_tree.GetName())

            root_file.Write()
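
Once the friend tree is attached, the constant branches can be used like ordinary branches of the annotated tree, for example in selections. A hypothetical check (file name, tree path and the branch "someBranch" are illustrations only; "annotation0" is the default branch name produced by the script above):

import ROOT

root_file = ROOT.TFile.Open("input.root", "READ")  # illustrative file name
tree = root_file.Get("path/to/tree")                # the path passed via --tree
# draw an existing branch, selecting on the constant friend branch
tree.Draw("someBranch", "annotation0 > 0.5", "goff")
root_file.Close()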
Example #7
    def all_runtimes(root_filename, plot_config_template=None):
        if plot_config_template is None:
            plot_config_template = {}

        plot_configs = []
        with tfilecontextmanager.TFileContextManager(root_filename,
                                                     "READ") as root_file:
            elements = roottools.RootTools.walk_root_directory(root_file)
            for key, path in elements:
                if path.endswith("runTime") and (
                        key.GetClassName().startswith("TTree")
                        or key.GetClassName().startswith("TNtuple")):
                    tree = root_file.Get(path)
                    branches = [
                        branch.GetName()
                        for branch in tree.GetListOfBranches()
                    ]

                    plot_config = copy.deepcopy(plot_config_template)
                    plot_config["files"] = [root_filename]
                    plot_config["folders"] = [path]
                    plot_config["nicks"] = ["nick"]
                    plot_config["x_expressions"] = [
                        str(x) for x in range(len(branches))
                    ]
                    plot_config["x_bins"] = [
                        str(len(branches)) + ",-0.5," +
                        str(len(branches) - 0.5)
                    ]
                    plot_config["y_expressions"] = branches
                    plot_config["weights"] = [
                        "({0}>=0.0)".format(branch) for branch in branches
                    ]
                    plot_config["tree_draw_options"] = ["profs"]
                    plot_config["x_label"] = ""
                    plot_config["x_ticks"] = range(len(branches))
                    plot_config["x_tick_labels"] = branches
                    plot_config["bottom_pad_margin"] = 0.5
                    plot_config["y_label"] = "Runtime per Event / #mus"
                    plot_config["y_log"] = True
                    plot_config["y_lims"] = [1e-1]
                    plot_config["title"] = os.path.dirname(path)
                    plot_config["canvas_width"] = 1800
                    plot_config["output_dir"] = os.path.join(
                        plot_config.get("output_dir", ""),
                        os.path.dirname(path))
                    plot_config["filename"] = "runtimes_per_event"
                    if "www" in plot_config:
                        plot_config["www"] = os.path.join(
                            plot_config.get("www", ""), os.path.dirname(path))
                    plot_configs.append(plot_config)

                    plot_config = copy.deepcopy(plot_config_template)
                    plot_config["files"] = [root_filename]
                    plot_config["folders"] = [path]
                    plot_config["nicks"] = ["nick"]
                    plot_config["x_expressions"] = [
                        str(x) for x in range(len(branches))
                    ]
                    plot_config["x_bins"] = [
                        str(len(branches)) + ",-0.5," +
                        str(len(branches) - 0.5)
                    ]
                    plot_config["weights"] = [
                        "({0}>=0.0)*{0}".format(branch) for branch in branches
                    ]
                    plot_config["x_label"] = ""
                    plot_config["x_ticks"] = range(len(branches))
                    plot_config["x_tick_labels"] = branches
                    plot_config["bottom_pad_margin"] = 0.5
                    plot_config["y_label"] = "Total Runtime / #mus"
                    plot_config["y_log"] = True
                    plot_config["title"] = os.path.dirname(path)
                    plot_config["canvas_width"] = 1800
                    plot_config["output_dir"] = os.path.join(
                        plot_config.get("output_dir", ""),
                        os.path.dirname(path))
                    plot_config["filename"] = "total_runtimes"
                    if "www" in plot_config:
                        plot_config["www"] = os.path.join(
                            plot_config.get("www", ""), os.path.dirname(path))
                    plot_configs.append(plot_config)

                    for branch in branches:
                        plot_config = copy.deepcopy(plot_config_template)
                        plot_config["files"] = [root_filename]
                        plot_config["folders"] = [path]
                        plot_config["x_expressions"] = [branch]
                        plot_config["weights"] = ["({0}>=0.0)".format(branch)]
                        plot_config["title"] = branch
                        plot_config["x_label"] = "Runtime per Event / #mus"
                        plot_config["output_dir"] = os.path.join(
                            plot_config.get("output_dir", ""),
                            os.path.dirname(path))
                        plot_config["filename"] = branch
                        if "www" in plot_config:
                            plot_config["www"] = os.path.join(
                                plot_config.get("www", ""),
                                os.path.dirname(path))
                        plot_configs.append(plot_config)
        return plot_configs
Example #8
    args = parser.parse_args()
    logger.initLogger(args)

    plot_configs = []
    for root_filename in args.root_files:
        config = {
            "files": [root_filename],
            "output_dir":
            os.path.dirname(root_filename),
            "filename":
            os.path.basename(root_filename).replace(
                ".root", "" if args.in_place else "_renamed"),
            "plot_modules": ["ExportRoot"],
        }
        with tfilecontextmanager.TFileContextManager(root_filename,
                                                     "READ") as root_file:
            elements = roottools.RootTools.walk_root_directory(root_file)
            for index, (key, path) in enumerate(elements):
                class_name = key.GetClassName()
                if any([
                        class_name.startswith(allowed_type)
                        for allowed_type in ["TH", "TProfile", "TGraph", "TF"]
                ]):

                    label = copy.deepcopy(path)
                    for search_regex, replacement in zip(
                            args.search_regexs, args.replacements):
                        label = re.sub(search_regex, replacement, label)

                    if (not args.in_place) or (label != path):
                        config.setdefault("x_expressions", []).append(path)
        pprint.pprint(plot_configs)
    higgsplot.HiggsPlotter(list_of_config_dicts=tmp_plot_configs,
                           list_of_args_strings=[""],
                           n_processes=args.n_processes,
                           n_plots=args.n_plots)

    # produce combined plots
    plot_config = {}
    tmp_input_files = [
        os.path.join(tmp_plot_config["output_dir"],
                     tmp_plot_config["filename"] + ".root")
        for tmp_plot_config in tmp_plot_configs
    ]

    root_object_paths = []
    with tfilecontextmanager.TFileContextManager(tmp_input_files[0],
                                                 "READ") as root_file:
        root_object_paths = zip(
            *roottools.RootTools.walk_root_directory(root_file))[-1]
    plot_indices = range(len(root_object_paths))
    if args.nicks is not None:
        root_object_paths, plot_indices = zip(*[(nick,
                                                 root_object_paths.index(nick))
                                                for nick in args.nicks
                                                if nick in root_object_paths])

    plot_config["files"] = tmp_input_files * len(root_object_paths)
    plot_config["x_expressions"] = list(
        itertools.chain(*[[root_object_path] * len(tmp_input_files)
                          for root_object_path in root_object_paths]))
    plot_config["nicks"] = [
        "_".join(str(part) for part in nick_parts)
                        help="Histogram name. [Default: %(default)s]")
    parser.add_argument(
        "-o",
        "--output",
        default=
        "$CMSSW_BASE/src/HiggsAnalysis/KITHiggsToTauTau/data/root/identificationWeights/isolationWeights_Data2015D_Spring15DR74Simulation_muon_MediumID_TightRelIso.root",
        help="Output ROOT file. [Default: %(default)s]")

    args = parser.parse_args()
    logger.initLogger(args)

    input_file = os.path.join(args.input_dir,
                              "MuonIso_Z_RunD_Reco74X_Nov20.root")
    input_path = "NUM_TightRelIso_DEN_MediumID_PAR_pt_spliteta_bin1/pt_abseta_ratio"

    with tfilecontextmanager.TFileContextManager(input_file,
                                                 "READ") as root_file:
        elements = roottools.RootTools.walk_root_directory(root_file)
        for key, path in elements:
            if path == input_path:
                temp_th2f = root_file.Get(path).Clone()

                root_file = ROOT.TFile(args.output, "RECREATE")
                pt_bins = array.array(
                    "d", [20.0, 25.0, 30.0, 40.0, 50.0, 60.0, 1000.0])
                eta_bins = array.array(
                    "d", [-2.4, -2.1, -1.2, -0.9, 0.0, 0.9, 1.2, 2.1, 2.4])
                histogram = ROOT.TH2F(args.histogram_name, args.histogram_name,
                                      len(pt_bins) - 1, pt_bins,
                                      len(eta_bins) - 1, eta_bins)

                # https://twiki.cern.ch/twiki/bin/viewauth/CMS/MuonReferenceEffsRun2
Example #11
def main():

    parser = argparse.ArgumentParser(
        description=
        "Check the validity of Artus outputs in a project directory.",
        parents=[logger.loggingParser])

    parser.add_argument("project_directory", help="Artus project directory")
    parser.add_argument("--dry-run",
                        help="Only print problematic job numbers",
                        default=False,
                        action="store_true")
    parser.add_argument("--no-resubmission",
                        help="Only print and invalidate problematic jobs",
                        default=False,
                        action="store_true")

    args = parser.parse_args()
    logger.initLogger(args)

    # GC settings
    artus_gc_config = os.path.join(args.project_directory,
                                   "grid-control_config.conf")

    n_artus_jobs = 0
    with gzip.open(
            os.path.join(args.project_directory, "workdir/params.map.gz"),
            "rb") as gc_params_map:
        n_artus_jobs = int(gc_params_map.read().strip())

    # locate artus outputs
    artus_outputs = {}

    artus_root_outputs = glob.glob(
        os.path.join(args.project_directory, "output/*/*_job_*_output.root"))
    for artus_root_output in progressiterator.ProgressIterator(
            artus_root_outputs, description="Locating Artus ROOT output"):
        artus_outputs.setdefault(
            int(
                re.search(".*_job_(?P<job_number>\d+)_output.root",
                          artus_root_output).groupdict().get("job_number")),
            {})["root"] = artus_root_output

    artus_log_outputs = glob.glob(
        os.path.join(args.project_directory, "output/*/*_job_*_log.log"))
    for artus_log_output in progressiterator.ProgressIterator(
            artus_log_outputs, description="Locating Artus log output"):
        artus_outputs.setdefault(
            int(
                re.search(".*_job_(?P<job_number>\d+)_log.log",
                          artus_log_output).groupdict().get("job_number")),
            {})["log"] = artus_log_output

    failed_job_numbers = []

    # check existence of all files
    for job_number in progressiterator.ProgressIterator(
            xrange(n_artus_jobs),
            description="Check existance of all Artus outputs"):
        if ((artus_outputs.get(job_number) is None)
                or (artus_outputs.get(job_number).get("root") is None)
                or (artus_outputs.get(job_number).get("log") is None)):

            failed_job_numbers.append(str(job_number))

    # check validity of ROOT files
    for job_number, outputs in progressiterator.ProgressIterator(
            artus_outputs.items(),
            description="Check validity of Artus ROOT outputs"):
        with tfilecontextmanager.TFileContextManager(outputs["root"],
                                                     "READ") as root_file:
            # https://root.cern.ch/root/roottalk/roottalk02/4340.html
            if root_file.IsZombie() or root_file.TestBit(
                    ROOT.TFile.kRecovered):
                failed_job_numbers.append(str(job_number))
            else:
                elements = roottools.RootTools.walk_root_directory(root_file)
                if len(elements) <= 1:
                    failed_job_numbers.append(str(job_number))

    if len(failed_job_numbers) == 0:
        log.info("No problematic Artus outputs found.")
    else:
        gc_reset_command = "go.py --reset id:" + (
            ",".join(failed_job_numbers)) + " " + artus_gc_config
        log.info(gc_reset_command)

        if not args.dry_run:
            logger.subprocessCall(shlex.split(gc_reset_command))

            if not args.no_resubmission:
                gc_run_command = "go.py " + artus_gc_config
                log.info(gc_run_command)
                logger.subprocessCall(shlex.split(gc_run_command))
Example #12
	def extract_shapes(self, root_filename_template,
	                   bkg_histogram_name_template, sig_histogram_name_template,
	                   bkg_syst_histogram_name_template, sig_syst_histogram_name_template,
	                   update_systematics=False):
		for analysis in self.cb.analysis_set():
			for era in self.cb.cp().analysis([analysis]).era_set():
				for channel in self.cb.cp().analysis([analysis]).era([era]).channel_set():
					for category in self.cb.cp().analysis([analysis]).era([era]).channel([channel]).bin_set():
						root_filename = root_filename_template.format(
								ANALYSIS=analysis,
								CHANNEL=channel,
								BIN=category,
								ERA=era
						)

						cb_backgrounds = self.cb.cp().analysis([analysis]).era([era]).channel([channel]).bin([category]).backgrounds()
						cb_backgrounds.ExtractShapes(
								root_filename,
								bkg_histogram_name_template.replace("{", "").replace("}", ""),
								bkg_syst_histogram_name_template.replace("{", "").replace("}", "")
						)

						cb_signals = self.cb.cp().analysis([analysis]).era([era]).channel([channel]).bin([category]).signals()
						cb_signals.ExtractShapes(
								root_filename,
								sig_histogram_name_template.replace("{", "").replace("}", ""),
								sig_syst_histogram_name_template.replace("{", "").replace("}", "")
						)

						# update/add systematics related to the estimation of backgrounds/signals
						# these uncertainties are stored in the input root files
						if update_systematics:
							with tfilecontextmanager.TFileContextManager(root_filename, "READ") as root_file:
								root_object_paths = [path for key, path in roottools.RootTools.walk_root_directory(root_file)]

								processes_histogram_names = []
								for process in cb_backgrounds.process_set():
									bkg_histogram_name = bkg_histogram_name_template.replace("$", "").format(
											ANALYSIS=analysis,
											CHANNEL=channel,
											BIN=category,
											ERA=era,
											PROCESS=process
									)
									yield_unc_rel = Datacards.get_yield_unc_rel(bkg_histogram_name, root_file, root_object_paths)
									if (yield_unc_rel is not None) and (yield_unc_rel != 0.0):
										cb_backgrounds.cp().process([process]).AddSyst(
												self.cb,
												"CMS_$ANALYSIS_$PROCESS_estimation_$ERA",
												"lnN",
												ch.SystMap("process")([process], 1.0+yield_unc_rel)
										)

								for process in cb_signals.process_set():
									for mass in cb_signals.mass_set():
										if mass != "*":
											sig_histogram_name = sig_histogram_name_template.replace("$", "").format(
													ANALYSIS=analysis,
													CHANNEL=channel,
													BIN=category,
													ERA=era,
													PROCESS=process,
													MASS=mass
											)
											yield_unc_rel = Datacards.get_yield_unc_rel(sig_histogram_name, root_file, root_object_paths)
											if (yield_unc_rel is not None) and (yield_unc_rel != 0.0):
												cb_signals.cp().process([process]).mass([mass]).AddSyst(
														self.cb,
														"CMS_$ANALYSIS_$PROCESS$MASS_estimation_$ERA",
														"lnN",
														ch.SystMap("process", "mass")([process], [mass], 1.0+yield_unc_rel)
												)

		if log.isEnabledFor(logging.DEBUG):
			self.cb.PrintAll()
Example #13
            args.lumi = samples.default_lumi / 1000.0
    else:
        log.critical("Invalid era string selected: " + args.era)
        sys.exit(1)

    args.signal_samples = [sample.split() for sample in args.signal_samples]
    args.background_samples = [
        sample.split() for sample in args.background_samples
    ]

    inputs_base = tools.longest_common_substring_from_list(
        [os.path.dirname(input_file) for input_file in args.input_files])

    input_file_content = {}
    for input_filename in args.input_files:
        with tfilecontextmanager.TFileContextManager(input_filename,
                                                     "READ") as input_file:
            for histogram in list(
                    zip(*roottools.RootTools.walk_root_directory(input_file))
                [-1]):
                input_file_content.setdefault(input_filename, {}).setdefault(
                    os.path.dirname(histogram),
                    []).append(os.path.basename(histogram))

    plot_configs = []
    for input_file, input_histograms in input_file_content.iteritems():
        short_directory = os.path.dirname(input_file).replace(inputs_base, "")

        for folder, histograms in input_histograms.iteritems():
            channel = folder[:folder.find("_")]
            category = folder[folder.find("_") + 1:].replace(
                "_prefit", "").replace("_postfit", "")
Example #14
File: sort-trees.py Project: whahmad/Artus
def main():

    parser = argparse.ArgumentParser(description="Sort trees.",
                                     parents=[logger.loggingParser])

    parser.add_argument("inputs",
                        nargs="+",
                        help="Input files containing the tree to sort.")
    parser.add_argument("-t",
                        "--tree",
                        required=True,
                        help="Path to the tree object in the file.")
    parser.add_argument("-b",
                        "--branches",
                        nargs="+",
                        default=["run", "lumi", "event"],
                        help="Branch names to be considered for the sorting.")
    parser.add_argument("-o",
                        "--output",
                        default="output.root",
                        help="Output ROOT file.")

    args = parser.parse_args()
    logger.initLogger(args)

    # TTree::Draw can fill at most four value buffers (GetV1 ... GetV4)
    args.branches = args.branches[:4]

    # https://root.cern.ch/root/roottalk/roottalk01/3646.html

    log.info("Opening input from")
    input_tree = ROOT.TChain()
    for input_file in args.inputs:
        path = os.path.join(input_file, args.tree)
        log.info("\t" + path)
        input_tree.Add(path)
    input_tree.SetCacheSize(128 * 1024 * 1024)
    n_entries = input_tree.GetEntries()

    values = [[] for index in xrange(len(args.branches))]
    n_entries_per_iteration = 10000000  # larger buffers cause problems
    for iteration in progressiterator.ProgressIterator(
            range(int(math.ceil(n_entries / float(n_entries_per_iteration)))),
            description="Retrieving branch values for sorting"):
        cut = "(Entry$>=({i}*{n}))*(Entry$<(({i}+1)*{n}))".format(
            i=iteration, n=n_entries_per_iteration)
        input_tree.Draw(":".join(args.branches), cut, "goff")
        buffers = [
            input_tree.GetV1(),
            input_tree.GetV2(),
            input_tree.GetV3(),
            input_tree.GetV4()
        ][:len(args.branches)]
        for index, input_buffer in enumerate(buffers):
            values[index].extend(
                list(
                    numpy.ndarray(input_tree.GetSelectedRows(),
                                  dtype=numpy.double,
                                  buffer=input_buffer)))

    log.info("Sorting of the tree entry indices...")
    values = zip(*([range(n_entries)] + values))
    values.sort(key=lambda value: value[1:])
    sorted_entries = list(zip(*values)[0])

    log.info("Creating output " + args.output + "...")
    with tfilecontextmanager.TFileContextManager(args.output,
                                                 "RECREATE") as output_file:
        output_tree = input_tree.CloneTree(0)
        for entry in progressiterator.ProgressIterator(
                sorted_entries, description="Copying tree entries"):
            input_tree.GetEntry(entry)
            output_tree.Fill()
        output_file.Write()
    log.info("Save sorted tree in " + os.path.join(args.output, args.tree) +
             ".")
Example #15
	for index_1 in xrange(args.n_variables):
		for index_2 in xrange(index_1, args.n_variables):
			covariances[index_1][index_2] = sigmas[index_1] * sigmas[index_2]
			if index_1 != index_2:
				covariances[index_2][index_1] = sigmas[index_1] * sigmas[index_2]
				
				correlation = random.uniform(min_correlation, max_correlation)
				covariances[index_1][index_2] *= correlation
				covariances[index_2][index_1] *= correlation
	
	covariance = ROOT.TMatrixDSym(args.n_variables, numpy.array(covariances).flatten())
	log.debug("Covariance matrix:")
	if log.isEnabledFor(logging.DEBUG):
		covariance.Print()
	
	multidim_gaussian = ROOT.RooMultiVarGaussian(
			"multidim_gaussian",
			"multidim_gaussian",
			ROOT.RooArgList(*variables),
			means,
			covariance
	)
	
	ROOT.RooDataSet.setDefaultStorageType(ROOT.RooAbsData.Tree)
	with tfilecontextmanager.TFileContextManager(args.output, "RECREATE") as root_file:
		for tree_name in args.tree_names:
			mc_dataset = multidim_gaussian.generate(ROOT.RooArgSet(*variables), args.n_events)
			mc_dataset.store().tree().Write(tree_name)
	log.info("Created tree(s) \"%s\" in output file \"%s\"." % ("\", \"".join(args.tree_names), args.output))