def configure_sum_factory(config, shared_vars):
    """Configure a factory for a sum of PDFs.

    Every entry of *config* other than 'yield' is configured as a child
    model through `configure_model`. The yield of each child is taken from
    its shared variables when present there, otherwise from the child's own
    'yield' entry. In both cases the 'yield' entry is removed from where it
    was found.

    Arguments:
        config (dict): Sum configuration. The 'yield' key is popped from it
            and from each child configuration.
        shared_vars (dict): Shared variables, indexed by child name.

    Return:
        The child factory itself when there is only one child, otherwise a
        `factory.SumPhysicsFactory` built from the children and their yields.

    Raise:
        ConfigError: If more than one yield is missing, or if exactly one is
            missing and the last child is not the one lacking it.

    """
    logger.debug("Configuring sum -> %s", dict(config))
    overall_yield = config.pop('yield', None)
    child_factories = OrderedDict()
    child_yields = OrderedDict()
    for name, child_config in config.items():
        child_shared = shared_vars[name]
        # A shared yield takes precedence over the one in the child config
        if 'yield' in child_shared:
            child_yields[name] = child_shared.pop('yield')
        if 'yield' in child_config:
            raw_yield = child_config.pop('yield')
            if name not in child_yields:
                child_yields[name] = sanitize_parameter(raw_yield,
                                                        'Yield', 'Yield')
        if isinstance(child_config.get('pdf'), str):
            # Leaf PDF: wrap it so that it is configured under its own name
            child_factories[name] = configure_model({name: child_config},
                                                    shared_vars)
        else:
            child_factories[name] = configure_model(child_config,
                                                    shared_vars[name])
    logger.debug("Found yields -> %s", child_yields)
    if len(child_factories) == 1:
        # A single child needs no sum: hand it back with its yield set
        (only_name, output_factory), = child_factories.items()
        if only_name in child_yields:
            output_factory.set_yield_var(child_yields[only_name])
    else:
        missing = len(child_factories) - len(child_yields)
        if missing > 1:
            raise ConfigError(
                "Missing at least one yield in sum factory definition")
        # With one yield missing, the last child must be the one without it
        if missing == 1 and \
                list(child_yields)[-1] == list(child_factories)[-1]:
            raise ConfigError("Wrong order in yield/factory specification")
        output_factory = factory.SumPhysicsFactory(child_factories,
                                                   child_yields, {})
    if overall_yield:
        output_factory.set_yield_var(overall_yield)
    return output_factory
def replace_globals(folded_data):
    """Replace values referencing globals and remove the globals section.

    A value is a reference when it is a string starting with ``globals.``;
    the remaining dot-separated parts are used as successive keys into the
    ``globals`` section. If the value found is itself a reference, it is
    resolved in turn.

    Args:
        folded_data (dict): The folded config, optionally containing a
            ``globals`` key. It is not modified (a shallow copy is used).

    Returns:
        OrderedDict : *folded_data* with the global keyword removed and every
        value containing the global keyword replaced by the value.

    Raises:
        ConfigError: If a reference cannot be resolved or refers back to
            itself.

    """
    GLOBALS_KEYWORD = 'globals'
    SEP = '.'
    prefix = GLOBALS_KEYWORD + SEP

    def is_reference(value):
        return isinstance(value, str) and value.startswith(prefix)

    folded_data = folded_data.copy()  # do not mutate arguments
    # gather globals
    yaml_globals = folded_data.pop(GLOBALS_KEYWORD, {})
    unfolded_data = unfold_config(folded_data)
    # replace globals
    replacements = []
    for key, val in unfolded_data:
        if not is_reference(val):
            continue
        resolved = val
        visited = set()
        while is_reference(resolved):
            if resolved in visited:
                # Without this guard a self-referencing global never ends
                raise ConfigError(
                    "Circular global reference '{}' in {key}".format(
                        val, key=key))
            visited.add(resolved)
            yaml_global = yaml_globals
            try:
                for glob_key in resolved.split(SEP)[1:]:  # skip identifier
                    yaml_global = yaml_global[glob_key]
            except (KeyError, TypeError):
                # TypeError: the path walks into something that is not a
                # mapping (a scalar, or an empty `globals` section)
                raise ConfigError(
                    "Invalid global reference '{}': value {key} not found".
                    format(resolved, key=key))
            resolved = yaml_global
        replacements.append((key, resolved))
    # The resolved entries go after the original ones, as before, and rely on
    # fold_config letting later entries take over earlier ones.
    unfolded_data.extend(replacements)
    return fold_config(unfolded_data, OrderedDict)
def configure_simul_factory(config, shared_vars):
    """Configure a factory for a simultaneous PDF.

    One `ROOT.RooCategory` is created per entry of ``config['categories']``;
    if there are several, they are combined in a `ROOT.RooSuperCategory`
    named after the category names joined by 'x'. The keys of
    ``config['pdf']`` are comma-separated labels, one sublabel per category,
    and each new sublabel is registered as a type of its category.

    Arguments:
        config (dict): Configuration with the 'categories' (comma-separated
            string or sequence of names) and 'pdf' (label -> model
            configuration) keys.
        shared_vars (dict): Shared variables; ``shared_vars['pdf'][label]``
            is passed to `configure_model` for each label.

    Return:
        `factory.SimultaneousPhysicsFactory` with each category also stored
        under the ``cat_<name>`` key.

    Raise:
        ConfigError: If a label has more sublabels than declared categories.

    """
    logger.debug("Configuring simultaneous -> %s", dict(config))
    cat_names = config['categories']
    if isinstance(cat_names, str):
        cat_names = cat_names.split(',')
    if len(cat_names) == 1:
        main_cat = ROOT.RooCategory(cat_names[0], cat_names[0])
        roo_cats = [main_cat]
    else:
        roo_cats = []
        members = ROOT.RooArgSet()
        for cat_name in cat_names:
            roo_cat = ROOT.RooCategory(cat_name, cat_name)
            roo_cats.append(roo_cat)
            members.add(roo_cat)
        super_name = 'x'.join(cat_names)
        main_cat = ROOT.RooSuperCategory(super_name, super_name, members)
    # Sublabels already registered, one set per category position
    seen = [set() for _ in range(len(cat_names))]
    for full_label in config['pdf']:
        for position, sublabel in enumerate(full_label.split(',')):
            sublabel = sublabel.strip()
            try:
                if sublabel in seen[position]:
                    continue
                logger.debug("Registering label for %s -> %s",
                             roo_cats[position].GetName(), sublabel)
                roo_cats[position].defineType(sublabel)
                seen[position].add(sublabel)
            except IndexError:
                # The label has more parts than there are categories
                logger.error(
                    "Mismatch between declared number of categories and label '%s'",
                    full_label)
                raise ConfigError(
                    "Badly defined category label '{}'".format(full_label))
    models = OrderedDict()
    for full_label, cat_config in config['pdf'].items():
        label_key = tuple(full_label.replace(' ', '').split(','))
        models[label_key] = configure_model(cat_config,
                                            shared_vars['pdf'][full_label])
    sim_factory = factory.SimultaneousPhysicsFactory(models, main_cat)
    for roo_cat in roo_cats:
        sim_factory.set('cat_{}'.format(roo_cat.GetName()), roo_cat)
    return sim_factory
def configure_prod_factory(config, shared_vars):
    """Configure a factory for a product of PDFs.

    Arguments:
        config (dict): Configuration whose 'pdf' key maps observables to
            factory configurations. When a product is built, 'pdf' is popped
            and the rest of *config* is passed on as the product parameters.
        shared_vars (dict): Shared variables; ``shared_vars['pdf'][obs]`` is
            given to each child and a 'yield' entry, if present, overrides
            the one in *config*.

    Return:
        The single child factory when 'pdf' has exactly one entry, otherwise
        a `factory.ProductPhysicsFactory`.

    Raise:
        ConfigError: If one of the children of a product defines a yield.

    """
    pdf_configs = config['pdf']
    logger.debug("Configuring product -> %s", pdf_configs)
    if len(pdf_configs) == 1:
        # Nothing to multiply: configure the only child directly
        (observable, factory_config), = pdf_configs.items()
        return configure_factory(observable, factory_config,
                                 shared_vars['pdf'][observable])
    # The yield belongs to the product as a whole, never to its children
    if any('yield' in child_config for child_config in pdf_configs.values()):
        raise ConfigError(
            "Yield of a RooProductPdf defined in one of the children.")
    if shared_vars and 'yield' in shared_vars:
        config['yield'] = shared_vars['yield']
    elif 'yield' in config:
        config['yield'] = sanitize_parameter(config['yield'],
                                             'Yield', 'Yield')
    # Create the product
    children = OrderedDict()
    for observable, factory_config in config.pop('pdf').items():
        children[observable] = configure_factory(
            observable, factory_config, shared_vars['pdf'][observable])
    return factory.ProductPhysicsFactory(children, parameters=config)
def configure_model(config, shared_vars=None, external_vars=None):
    """Build a physics factory from a model configuration.

    The layout of *config* selects what is built:

    - With a 'categories' key, a simultaneous factory.
    - Without a 'pdf' key, the entries other than 'yield' are child models.
      If the 'pdf' of the first child is a string, the entries are handled
      as the observables of a product; otherwise as the terms of a sum.
    - With a 'pdf' key holding more than one entry, a product factory.
    - With a 'pdf' key holding one entry, the factory for that observable.

    Arguments:
        config (dict): Model configuration. It is copied with
            `recursive_dict_copy` so that the argument is not altered.
        shared_vars (dict, optional): Shared variables. If `None`, they are
            obtained from `get_shared_vars(config, external_vars)`.
        external_vars (optional): Passed on to `get_shared_vars`.

    Return:
        The configured physics factory.

    Raise:
        ConfigError: If the shared parameters are badly configured, if the
            configuration defines no model at all, or if a sum, product or
            simultaneous definition is inconsistent.

    """
    # Imported here rather than at module level, as in the original code
    # (presumably to avoid a circular import -- TODO confirm).
    import analysis.physics.factory as factory

    def sanitize_parameter(param, name, title):
        """Return a (parameter, constraint) pair for a yield-like value."""
        constraint = None
        if isinstance(param, (list, tuple)):
            param, constraint = param
        if not isinstance(param, ROOT.TObject):
            # Not a ROOT object yet: build it from its configuration
            param, constraint = configure_parameter(name, title, param)
        return param, constraint

    def configure_factory(observable, config, shared_vars):
        """Instantiate the physics factory of a single observable."""
        logger.debug("Configuring factory -> %s", config)
        if 'yield' in config:
            yield_ = config.pop('yield')
            # An already shared yield takes precedence
            if 'yield' not in shared_vars:
                shared_vars['yield'] = sanitize_parameter(
                    yield_, 'Yield', 'Yield')
        factory_class = get_physics_factory(observable, config['pdf'])
        return factory_class(config, shared_vars)

    def configure_prod_factory(config, shared_vars):
        """Configure a product of PDFs (or its only child)."""
        pdf_configs = config['pdf']
        logger.debug("Configuring product -> %s", pdf_configs)
        if len(pdf_configs) == 1:
            (observable, factory_config), = pdf_configs.items()
            return configure_factory(observable, factory_config,
                                     shared_vars['pdf'][observable])
        # The yield belongs to the product, never to its children
        if any('yield' in child_config
               for child_config in pdf_configs.values()):
            raise ConfigError(
                "Yield of a RooProductPdf defined in one of the children.")
        if shared_vars and 'yield' in shared_vars:
            config['yield'] = shared_vars['yield']
        elif 'yield' in config:
            config['yield'] = sanitize_parameter(config['yield'],
                                                 'Yield', 'Yield')
        # Create the product; what is left of config becomes its parameters
        children = OrderedDict()
        for observable, factory_config in config.pop('pdf').items():
            children[observable] = configure_factory(
                observable, factory_config, shared_vars['pdf'][observable])
        return factory.ProductPhysicsFactory(children, parameters=config)

    def configure_sum_factory(config, shared_vars):
        """Configure a sum of PDFs (or its only child)."""
        logger.debug("Configuring sum -> %s", dict(config))
        global_yield = config.pop('yield', None)
        factories = OrderedDict()
        yields = OrderedDict()
        for pdf_name, pdf_config in config.items():
            pdf_shared = shared_vars[pdf_name]
            # A shared yield takes precedence over the configured one
            if 'yield' in pdf_shared:
                yields[pdf_name] = pdf_shared.pop('yield')
            if 'yield' in pdf_config:
                yield_ = pdf_config.pop('yield')
                if pdf_name not in yields:
                    yields[pdf_name] = sanitize_parameter(
                        yield_, 'Yield', 'Yield')
            if isinstance(pdf_config.get('pdf'), str):
                factories[pdf_name] = configure_model({pdf_name: pdf_config},
                                                      shared_vars)
            else:
                factories[pdf_name] = configure_model(pdf_config,
                                                      shared_vars[pdf_name])
        logger.debug("Found yields -> %s", yields)
        if len(factories) == 1:
            # A single child needs no sum: hand it back with its yield set
            (factory_name, output_factory), = factories.items()
            if factory_name in yields:
                output_factory.set_yield_var(yields[factory_name])
        else:
            missing = len(factories) - len(yields)
            if missing > 1:
                raise ConfigError(
                    "Missing at least one yield in sum factory definition")
            # The last one should not have a yield!
            if missing == 1 and list(yields)[-1] == list(factories)[-1]:
                raise ConfigError(
                    "Wrong order in yield/factory specification")
            output_factory = factory.SumPhysicsFactory(factories, yields, {})
        if global_yield:
            output_factory.set_yield_var(global_yield)
        return output_factory

    def configure_simul_factory(config, shared_vars):
        """Configure a simultaneous PDF over one or more categories."""
        logger.debug("Configuring simultaneous -> %s", dict(config))
        categories = config['categories']
        if isinstance(categories, str):
            categories = categories.split(',')
        if len(categories) == 1:
            main_cat = ROOT.RooCategory(categories[0], categories[0])
            cat_list = [main_cat]
        else:
            cat_list = []
            cat_set = ROOT.RooArgSet()
            for cat_name in categories:
                roo_cat = ROOT.RooCategory(cat_name, cat_name)
                cat_list.append(roo_cat)
                cat_set.add(roo_cat)
            super_name = 'x'.join(categories)
            main_cat = ROOT.RooSuperCategory(super_name, super_name, cat_set)
        # Sublabels already registered, one set per category position
        labels = [set() for _ in range(len(categories))]
        for cat_label in config['pdf']:
            for cat_iter, cat_sublabel in enumerate(cat_label.split(',')):
                cat_sublabel = cat_sublabel.strip()
                try:
                    if cat_sublabel in labels[cat_iter]:
                        continue
                    logger.debug("Registering label for %s -> %s",
                                 cat_list[cat_iter].GetName(), cat_sublabel)
                    cat_list[cat_iter].defineType(cat_sublabel)
                    labels[cat_iter].add(cat_sublabel)
                except IndexError:
                    # The label has more parts than there are categories
                    logger.error(
                        "Mismatch between declared number of categories and label '%s'",
                        cat_label)
                    raise ConfigError(
                        "Badly defined category label '{}'".format(cat_label))
        models = OrderedDict()
        for cat_label, cat_config in config['pdf'].items():
            label_key = tuple(cat_label.replace(' ', '').split(','))
            models[label_key] = configure_model(
                cat_config, shared_vars['pdf'][cat_label])
        sim_factory = factory.SimultaneousPhysicsFactory(models, main_cat)
        for roo_cat in cat_list:
            sim_factory.set('cat_{}'.format(roo_cat.GetName()), roo_cat)
        return sim_factory

    # copy: to not alter argument; shallow (not deep!): do not duplicate ROOT
    # objects
    config = recursive_dict_copy(config, to_copy=(list, tuple))
    # Prepare shared variables
    if shared_vars is None:
        try:
            shared_vars = get_shared_vars(config, external_vars)
        except (ValueError, KeyError) as error:
            raise ConfigError(error)
    # Let's find out what is this
    if 'categories' in config:
        return configure_simul_factory(config, shared_vars)
    if 'pdf' not in config:
        # The first entry that is not the global yield tells sums and
        # products apart.
        children = [child for key, child in config.items() if key != 'yield']
        if not children:
            raise ConfigError("Model configuration does not define any PDF")
        if isinstance(children[0].get('pdf'), str):
            return configure_prod_factory({'pdf': config},
                                          {'pdf': shared_vars})
        return configure_sum_factory(config, shared_vars)
    if len(config['pdf']) > 1:
        return configure_prod_factory(config, shared_vars)
    # Single observable: push the model-level parameters down to it
    pdf_obs = list(config['pdf'].keys())[0]
    pdf_config = list(config['pdf'].values())[0]
    if 'parameters' not in pdf_config:
        pdf_config['parameters'] = OrderedDict()
    pdf_config['parameters'].update(config.get('parameters', {}))
    # NOTE(review): the copy is shallow, so the update below also changes the
    # 'parameters' dict held by shared_vars -- confirm this is intended.
    sh_vars = shared_vars['pdf'][pdf_obs].copy()
    if 'parameters' in sh_vars:
        sh_vars['parameters'].update(shared_vars['parameters'])
    else:
        sh_vars['parameters'] = shared_vars['parameters']
    return configure_factory(observable=pdf_obs,
                             config=pdf_config,
                             shared_vars=sh_vars)
def load_config(*file_names, **options):
    """Load configuration from YAML files.

    If more than one is specified, they are loaded in the order given in the
    function call. Therefore, the latter will override the former if key
    overlap exists.

    Currently supported options are:
        - `root` (str), which determines the node that is considered as root.
        - `validate` (list), which gets a list of keys to check. If one of
            these keys is not present, `config.ConfigError` is raised.

    Additionally, several commands are available to modify the
    configurations:
        - The `load` key can be used to load other config files from the
            config file. The value of this key can have two formats:

            + `file_name:key` inserts the contents of `key` in `file_name` at
                the same level as the `load` entry. `file_name` is relative
                to the directory of the config file or, if several files are
                given, to the directory of their common path prefix.
            + `path_func:name:key` inserts the contents `key` in the file
                obtained by the `get_{path_func}_path(name)` call at the same
                level as the `load` entry.
        - The `modify` command can be used to modify a previously loaded
            key/value pair. It has the format `key: value` and replaces `key`
            at its same level by the value given by `value`. For more
            complete examples and documentation, see the README.
        - The `globals` key can be used to define global variables. Access is
            via a value written as "globals.path_to.myvar" with a
            configuration like: {globals: {path_to: {myvar: myval}},....}.
            This will replace it with `myval`.

    Arguments:
        *file_names (list[str]): Files to load.
        **options (dict): Configuration options. See above for supported
            options.

    Return:
        dict: Configuration.

    Raise:
        OSError: If some file does not exist.
        KeyError: If a file cannot be parsed as YAML.
        ConfigError: If key loading or validation fail.

    """
    unfolded_data = []
    for file_name in file_names:
        if not os.path.exists(file_name):
            raise OSError("Cannot find config file -> {}".format(file_name))
        try:
            with open(file_name) as input_obj:
                # NOTE(review): this loader is not one of the "safe" YAML
                # loaders; only load configuration files that are trusted.
                unfolded_data.extend(
                    unfold_config(
                        yaml.load(input_obj,
                                  Loader=yamlloader.ordereddict.CLoader)))
        except yaml.parser.ParserError as error:
            raise KeyError(str(error))
    # Load required data
    unfolded_data_expanded = []
    root_prev_load = None
    for key, val in unfolded_data:
        command = key.split('/')[-1]
        if command == 'load':  # An input requirement has been made
            split_val = val.split(":")
            if len(split_val) == 2:  # file_name:key format
                file_name_result, required_key = split_val
                if not os.path.isabs(file_name_result):
                    if len(file_names) == 1:
                        base_dir = os.path.split(file_names[0])[0]
                    else:
                        # commonprefix takes one sequence and compares
                        # character by character, so the prefix may stop in
                        # the middle of a name: keep its directory part only.
                        base_dir = os.path.dirname(
                            os.path.commonprefix(list(file_names)))
                    file_name_result = os.path.join(base_dir,
                                                    file_name_result)
            elif len(split_val) == 3:  # path_func:name:key format
                path_name, name, required_key = split_val
                import analysis.utils.paths as _paths
                try:
                    path_func = getattr(_paths,
                                        'get_{}_path'.format(path_name))
                except AttributeError:
                    raise ConfigError(
                        "Unknown path getter type -> {}".format(path_name))
                file_name_result = path_func(name)
            else:
                raise ConfigError("Malformed 'load' key")
            # Strip only the trailing '/load' so that a parent whose name
            # starts with 'load' is left untouched.
            # NOTE(review): a top-level 'load' key has no '/load' suffix and
            # keeps 'load' as root, as before -- confirm that is intended.
            root = key.rsplit('/load', 1)[0]
            try:
                for new_key, new_val in unfold_config(
                        load_config(file_name_result, root=required_key)):
                    unfolded_data_expanded.append(
                        ('{}/{}'.format(root, new_key), new_val))
            except Exception:
                logger.error("Error loading required data in %s",
                             required_key)
                raise
            else:
                root_prev_load = root
        elif root_prev_load and key.startswith(root_prev_load + '/'):
            # A key under a node filled by 'load': only 'modify' is allowed
            relative_key = key.split(root_prev_load + '/', 1)[1]
            if not relative_key.startswith('modify/'):
                logger.error(
                    "Key %s cannot be used without 'modify' if 'load' came before.",
                    key)
                raise ConfigError(
                    "Loaded pdf with 'load' can *only* be modified by using 'modify'."
                )
            key_to_replace = '{}/{}'.format(
                root_prev_load, relative_key.split('modify/', 1)[1])
            loaded_keys = [loaded_key
                           for loaded_key, _ in unfolded_data_expanded]
            try:
                key_index = loaded_keys.index(key_to_replace)
            except ValueError:  # list.index signals a missing item this way
                logger.error("Cannot find key to modify -> %s",
                             key_to_replace)
                raise ConfigError("Malformed 'modify' key")
            unfolded_data_expanded[key_index] = (key_to_replace, val)
        else:
            root_prev_load = None  # reset, there was no 'load'
            unfolded_data_expanded.append((key, val))
    # Fold back
    data = fold_config(unfolded_data_expanded, OrderedDict)
    # Replace globals
    data = replace_globals(data)
    logger.debug('Loaded configuration -> %s', data)
    data_root = options.get('root', '')
    if data_root:
        for root_node in data_root.split('/'):
            try:
                data = data[root_node]
            except KeyError:
                raise ConfigError(
                    "Root node {} of {} not found in dataset".format(
                        root_node, data_root))
    if 'validate' in options:
        # Every key together with all of its parent paths.
        # NOTE(review): built from the entries as read from the files, i.e.
        # before 'load'/'modify' are applied -- confirm that is intended.
        data_keys = []
        for data_key, _ in unfolded_data:
            parts = data_key.split('/')
            data_keys.extend('/'.join(parts[:entry_num + 1])
                             for entry_num in range(len(parts)))
        logger.debug("Validating against the following keys -> %s",
                     ', '.join(data_keys))
        known_keys = set(data_keys)
        missing_keys = [
            full_key
            for full_key in (os.path.join(data_root, validate_key)
                             for validate_key in options['validate'])
            if full_key not in known_keys
        ]
        if missing_keys:
            raise ConfigError(
                "Failed validation: {} are missing".format(
                    ','.join(missing_keys)), missing_keys)
    return data