예제 #1
0
def register_config_action(name, action_function):
    """Register a configuration action keyword.

    The action is stored in the `PARAMETER_KEYWORDS` global variable under
    `name`. It must take exactly four positional parameters:
        - The name of the parameter
        - The title of the parameter
        - The parameter configuration
        - The list of shared vars.

    Arguments:
        name (str): Name of the action.
        action_function (Callable): Parameter config action.

    Return:
        int: Number of registered config actions.

    Raise:
        ValueError: If the configuration action function doesn't have the
            correct number of parameters.

    """
    try:  # Python 3 API
        arg_names = inspect.getfullargspec(action_function).args
    except AttributeError:  # Python 2 has no `getfullargspec`
        arg_names = inspect.getargspec(action_function).args
    if len(arg_names) != 4:
        raise ValueError(
            "The action configuration function needs to have 4 arguments")
    logger.debug("Registering %s parameter configuration keyword", name)
    keyword_registry = get_global_var('PARAMETER_KEYWORDS')
    keyword_registry[name] = action_function
    return len(keyword_registry)
예제 #2
0
def register_file_type(extension, file_type):
    """Associate a file extension with a file type.

    A leading dot is prepended to `extension` when it is missing, and the
    mapping is stored in the `FILE_TYPES` global variable (overriding any
    type previously registered for that extension).

    Arguments:
        extension (str): File extension, with or without the leading dot.
        file_type: File type identifier to associate with the extension.

    Return:
        int: Number of extensions registered with that particular file type.

    """
    dotted_ext = extension if extension.startswith('.') else "." + extension
    logger.debug("Registering file type for extension %s -> %s", dotted_ext,
                 file_type)
    type_registry = get_global_var('FILE_TYPES')
    type_registry[dotted_ext] = file_type
    return len([registered for registered in type_registry.values()
                if registered == file_type])
예제 #3
0
def register_efficiency_model(model_name, model_class):
    """Make an efficiency model available under `model_name`.

    The class is stored in the `EFFICIENCY_MODELS` global variable, from which
    the `get_efficiency_model` functions retrieve it. A model already
    registered under the same name is replaced.

    Note:
        `get_efficiency_model_class` looks models up by the lowercased name,
        so `model_name` should be lowercase to be found there.

    Arguments:
        model_name (str): Name of the model.
        model_class (`Efficiency`): Efficiency model to register.

    Return:
        int: Number of registered efficiency models

    """
    logger.debug("Registering efficiency model -> %s", model_name)
    model_registry = get_global_var('EFFICIENCY_MODELS')
    model_registry[model_name] = model_class
    return len(model_registry)
예제 #4
0
def _get_path(dirs, extension, name_transformation, name, *args, **kwargs):
    """Build the path of an object below `BASE_PATH`.

    The path is $BASE_PATH/{'/'.join(dirs)}/{name_transformation(name, args, kwargs)}{extension}.
    The extension is only appended when the joined path does not already end
    with it.

    Arguments:
        dirs (list): Parent directories of the object path (must be a list,
            since it is concatenated with lists).
        extension (str): Extension of the file (including the dot). Can be
            empty for directories.
        name_transformation (Callable): Function called as
            `name_transformation(name, args, kwargs)` to obtain the final path
            component. Note that `args` and `kwargs` are passed as a tuple and
            a dict, not unpacked.
        name (str): Name of the object.
        *args: Positional arguments, forwarded as a tuple to
            `name_transformation`.
        **kwargs: Keyword arguments, forwarded as a dict to
            `name_transformation`.

    Return:
        str: Path of the object (absolute as long as `BASE_PATH` is).

    """
    assert not extension or extension.startswith('.'), \
        "Extension is expected to start with '.'. Given extension: {}".format(extension)

    parents = [get_global_var('BASE_PATH')] + dirs
    leaf = name_transformation(name, args, kwargs)
    joined = os.path.join(*(parents + [leaf]))
    return joined if joined.endswith(extension) else joined + extension
예제 #5
0
def register_physics_factories(observable, factories):
    """Register physics factories for an observable.

    The given factories are merged into the `PHYSICS_FACTORIES[observable]`
    mapping of the global variables, which makes them available to
    `get_physics_factory`. Factories already registered under the same name
    for this observable are replaced.

    Note:
        `PHYSICS_FACTORIES[observable]` is indexed directly. NOTE(review): this
        assumes the entry already exists (or that `PHYSICS_FACTORIES` is a
        defaultdict); otherwise a KeyError is raised. Confirm.

    Arguments:
        observable (str): Observable name.
        factories (dict): Factory name -> factory class mapping.

    Return:
        int: Number of registered physics factories for the given observable.

    """
    logger.debug("Registering factories for the '%s' observable -> %s",
                 observable, factories)
    observable_factories = get_global_var('PHYSICS_FACTORIES')[observable]
    observable_factories.update(factories)
    return len(observable_factories)
예제 #6
0
def get_efficiency_model_class(model_name):
    """Look up a registered efficiency model class.

    The lookup in the `EFFICIENCY_MODELS` global variable uses the lowercased
    `model_name`, so only models registered under a lowercase name are found.

    Arguments:
        model_name (str): Name of the efficiency model class.

    Return:
        `Efficiency`: Efficiency class, non-instantiated, or None if no model
            is registered under the lowercased name.

    """
    model_registry = get_global_var('EFFICIENCY_MODELS')
    lookup_key = model_name.lower()
    return model_registry.get(lookup_key)
예제 #7
0
def register_toy_randomizer(name, rand_class):
    """Register a randomized toy generator.

    Randomizers are registered in the `TOY_RANDOMIZERS` global variable.
    A randomizer already registered under `name` is replaced.

    Arguments:
        name (str): Name of the randomizer.
        rand_class (ToyRandomizer): Randomizer class to register. Must be a
            class deriving from `ToyRandomizer`, not an instance.

    Return:
        int: Number of registered randomizers.

    Raise:
        ValueError: If `rand_class` is not a subclass of `ToyRandomizer`
            (including when it is not a class at all).

    """
    from analysis.toys.randomizers import ToyRandomizer
    logger.debug("Registering %s toy randomizer", name)
    # `issubclass` raises TypeError for non-class arguments, so check that we
    # were given a class first to honour the documented ValueError.
    if not (isinstance(rand_class, type) and issubclass(rand_class, ToyRandomizer)):
        # Report the offending object itself: `type(cls)` is just the metaclass.
        raise ValueError("Wrong class type -> {}".format(rand_class))
    randomizer_registry = get_global_var('TOY_RANDOMIZERS')
    randomizer_registry[name] = rand_class
    return len(randomizer_registry)
예제 #8
0
def get_config_action(name):
    """Return the configuration action registered under `name`.

    Actions are looked up in the `PARAMETER_KEYWORDS` global variable.

    Arguments:
        name (str): Name of the configuration keyword.

    Return:
        Callable: The parameter configuration function.

    Raise:
        KeyError: If the keyword is not registered.

    """
    action_registry = get_global_var('PARAMETER_KEYWORDS')
    return action_registry[name]
예제 #9
0
def add_pdf_paths(*paths):
    """Prepend paths to the global 'PDF_PATHS' variable, skipping duplicates.

    The inserted paths take preference over the ones already present. New
    paths end up at the front of the list in the order they were given.
    Paths that are already present are left where they are.

    Note:
        Relative paths are interpreted relative to `BASE_PATH` and converted
        to absolute paths before being added.

    Arguments:
        *paths (list): List of paths to add.

    Return:
        list: Updated PDF paths.

    """
    base_path = get_global_var('BASE_PATH')
    # Walk backwards so that repeated insertion at index 0 keeps argument order.
    for candidate in paths[::-1]:
        if os.path.isabs(candidate):
            full_path = candidate
        else:
            full_path = os.path.abspath(os.path.join(base_path, candidate))
        if full_path in get_global_var('PDF_PATHS'):
            continue
        _logger.debug("Adding %s to PDF_PATHS", full_path)
        get_global_var('PDF_PATHS').insert(0, full_path)
    return get_global_var('PDF_PATHS')
예제 #10
0
def get_randomizer(rand_config):
    """Look up a randomized toy generator class.

    The randomizer is selected by the `type` key of `rand_config` among those
    stored in the `TOY_RANDOMIZERS` global variable.

    Arguments:
        rand_config (dict): Configuration of toy randomizer.

    Return:
        ToyRandomizer class

    Raise:
        KeyError: If the randomizer type is unknown, or if `rand_config` has
            no `type` key.

    """
    randomizer_registry = get_global_var('TOY_RANDOMIZERS')
    return randomizer_registry[rand_config['type']]
예제 #11
0
def get_physics_factory(observable, pdf_type):
    """Look up the physics factory for an observable and PDF type.

    Arguments:
        observable (str): Observable name.
        pdf_type (str): Type of the pdf.

    Return:
        `PhysicsFactory`: Requested PhysicsFactory.

    Raise:
        KeyError: If the observable or the type of factory is unknown.

    """
    registry = get_global_var('PHYSICS_FACTORIES')
    if observable not in registry:
        raise KeyError("Unknown observable type -> {}".format(observable))
    observable_factories = registry[observable]
    if pdf_type not in observable_factories:
        raise KeyError("Unknown PDF type -> {}".format(pdf_type))
    return observable_factories[pdf_type]
예제 #12
0
def prepare_path(name, path_func, link_from, *args, **kwargs):
    """Build the folder structure for any output.

    The output file name is obtained from `path_func` and the possibility of
    having soft links is taken into account through the `link_from` argument.

    It takes the output file_name and builds all the folder structure from
    `BASE_PATH` until that file. If soft links have to be taken into account,
    the same relative path is built from the `link_from` folder.

    Arguments:
        name (str): Name of the job. To be passed to `path_func`.
        path_func (Callable): Function to execute to get the path.
        link_from (str): Base directory for symlinking. If `None`, no symlinking
            is done.
        *args (list): Extra arguments for the `path_func`.
        **kwargs (dict): Extra arguments for the `path_func`.

    Return:
        tuple (bool, str, str): Need to do soft-linking, path of true output file,
            path of soft-link output.

    Raise:
        OSError: If `link_from` is given, differs from `BASE_PATH` and does
            not exist.

    """
    dest_base_dir = get_global_var('BASE_PATH')
    src_base_dir = link_from or dest_base_dir
    do_link = dest_base_dir != src_base_dir
    if do_link and not os.path.exists(src_base_dir):
        raise OSError(
            "Cannot find storage folder -> {}".format(src_base_dir))
    dest_file_name = path_func(name, *args, **kwargs)
    rel_file_name = os.path.relpath(dest_file_name, dest_base_dir)
    src_file_name = os.path.join(src_base_dir, rel_file_name)
    # Create dirs. `rel_file_name` is relative to BASE_PATH, so the
    # "is it a directory?" test must be done on the path inside BASE_PATH;
    # testing the bare relative name would resolve it against the CWD.
    if os.path.isdir(os.path.join(dest_base_dir, rel_file_name)):
        rel_dir = rel_file_name
    else:
        rel_dir = os.path.dirname(rel_file_name)
    for base_dir in (dest_base_dir, src_base_dir):
        target_dir = os.path.join(base_dir, rel_dir)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)
    return do_link, src_file_name, dest_file_name
예제 #13
0
def load_pdf_by_name(name, use_mathmore=False):
    """Load a PDF class by name, compiling it if needed.

    The library is searched for in the directories of the `PDF_PATHS` global
    variable. The PDF class is then fetched from the `ROOT` namespace using
    the base name of `name` without its extension.

    Arguments:
        name (str): Name of the PDF to load.
        use_mathmore (bool, optional): Load libMathMore before compiling.
            Defaults to False.

    Return:
        `ROOT.RooAbsPdf`: RooFit PDF object.

    Raise:
        OSError: If the .cc file corresponding to `name` cannot be found.

    """
    try:
        search_dirs = get_global_var('PDF_PATHS')
        _load_library(name, lib_dirs=search_dirs, use_mathmore=use_mathmore)
    except OSError:
        raise OSError("Don't know this PDF! -> {}".format(name))
    class_name = os.path.splitext(os.path.basename(name))[0]
    return getattr(ROOT, class_name)
예제 #14
0
def main():
    """Toy fitting submission application.

    Parses the command line, configures the toy fitters and submits the
    jobs, catching intermediate errors and transforming them to status codes.

    If the configuration has a `scan` section, one temporary configuration
    file is written for every combination of scan values and submitted as a
    separate job. Each temporary file is deleted after its job is submitted,
    and any leftovers are deleted when processing stops, even on error.

    Status codes:
        0: All good.
        1: Error in the configuration files.
        2: Error in preparing the output folders.
        3: Conflicting options given.
        4: A non-matching configuration file was found in the output.
        5: The queue submission command cannot be found.
        128: Uncaught error. An exception is logged.

    """

    def flatten(list_, typ_):
        """Concatenate a sequence of `typ_` sequences into a single list."""
        return list(sum(list_, typ_))

    def remove_temp_files(file_names):
        """Best-effort removal of the temporary scan configuration files."""
        for temp_name in file_names:
            if os.path.exists(temp_name):
                try:
                    os.remove(temp_name)
                except OSError as rm_error:
                    logger.warning("Cannot remove temporary file %s -> %s",
                                   temp_name, rm_error)

    parser = argparse.ArgumentParser()
    parser.add_argument('-v', '--verbose',
                        action='store_true',
                        help="Verbose output")
    parser.add_argument('--link-from',
                        action='store', type=str,
                        help="Folder to actually store the toy files")
    parser.add_argument('--extend',
                        action='store_true', default=False,
                        help="Extend previous production")
    parser.add_argument('--overwrite',
                        action='store_true', default=False,
                        help="Overwrite previous production")
    parser.add_argument('config',
                        action='store', type=str, nargs='+',
                        help="Configuration file")
    args = parser.parse_args()
    if args.verbose:
        get_logger('analysis').setLevel(1)
        logger.setLevel(1)
    # When `scan_config` is True, `config_files` only holds temporary files
    # created here; otherwise it is the user's list and must never be deleted.
    scan_config = False
    config_files = []
    try:
        config = _config.load_config(*args.config)
        # Which type of toy are we running?
        script_to_run = None
        submitter = None
        for toy_type, (toy_class, script_name) in TOY_TYPES.items():
            if toy_type in config:
                script_to_run = script_name
                submitter = toy_class
        if submitter is None:
            raise KeyError("Unknown job type")
        # Is there something to scan?
        scan_config = 'scan' in config
        if scan_config:
            base_config = _config.unfold_config(config)
            scan_groups = []
            for scan_group in config['scan']:
                scan_group_dict = {}
                for key, val_str in scan_group.items():
                    scan_group_dict[key] = process_scan_val(val_str, scan_group_dict)
                scan_groups.append(scan_group_dict)
            # Check lengths: parameters within a group are scanned in lockstep
            if not all(len({len(val) for val in scan_group.values()}) == 1
                       for scan_group in scan_groups):
                raise ValueError("Unmatched length in scan parameters")
            # Build values to scan
            scan_keys, scan_value_lists = list(
                zip(*[zip(*scan_group.items()) for scan_group in scan_groups]))
            scan_keys = flatten(scan_keys, tuple())
            for value_tuple in itertools.product(*[zip(*val) for val in scan_value_lists]):
                scan_values = dict(zip(scan_keys, flatten(value_tuple, tuple())))
                temp_config = dict(base_config)
                del temp_config['scan']
                temp_config['name'] = temp_config['name'].format(**scan_values)
                temp_config.update(scan_values)
                logger.debug("Creating configuration %s for scan values -> %s",
                             temp_config['name'],
                             ", ".join('{}: {}'.format(*val) for val in scan_values.items()))
                # Write temp_file
                with tempfile.NamedTemporaryFile(delete=False) as file_:
                    file_name = file_.name
                # Track the file before writing it so it is cleaned up on failure
                config_files.append(file_name)
                _config.write_config(_config.fold_config(list(temp_config.items())), file_name)
        else:
            config_files = args.config
    # pylint: disable=W0703
    except Exception:
        logger.exception("Bad configuration given")
        if scan_config:
            remove_temp_files(config_files)
        parser.exit(1)
    try:
        script_to_run = os.path.join(get_global_var('BASE_PATH'),
                                     'toys',
                                     script_to_run)
        for config_file in config_files:
            submitter(config_files=[config_file],
                      link_from=args.link_from,
                      extend=args.extend,
                      overwrite=args.overwrite,
                      verbose=args.verbose).run(script_to_run)
            if scan_config:
                os.remove(config_file)
        exit_status = 0
    except KeyError:
        logger.error("Bad configuration given")
        exit_status = 1
    except OSError as error:
        logger.error(str(error))
        exit_status = 2
    except ValueError:
        logger.error("Conflicting options found")
        exit_status = 3
    except AttributeError:
        logger.error("Mismatching configuration found")
        exit_status = 4
    except AssertionError:
        logger.error("Cannot find the queue submission command")
        exit_status = 5
    # pylint: disable=W0703
    except Exception as error:
        exit_status = 128
        logger.exception('Uncaught exception -> %s', repr(error))
    finally:
        # Runs even on KeyboardInterrupt, so no temporary files are left behind.
        if scan_config:
            remove_temp_files(config_files)
    # Outside `finally`: a BaseException propagates instead of hitting an
    # unbound `exit_status`.
    parser.exit(exit_status)
예제 #15
0
def get_data(data_config, **kwargs):
    """Get data.

    Detects the input file extension and uses the proper loader.

    The required configuration keys are:
        + `source`: Input source. If `source-type` is specified, the file name will
            be obtained executing `get_{source-type}_path`, otherwise `source` is
            treated as a file name.
        + `output-format`: Type of data we want. Currently `root` or `pandas`.

    Optional config keys:
        + `tree`: Tree within the file. Defaults to ''.
        + `source-type`: See `source`.
        + `input-type`: type of input, in case the extension has not been registered.

    Any other keys are forwarded to the loader function.

    If `data_config` is a list, each element is loaded separately and the
    results are merged with `analysis.data.mergers.merge`, which receives
    `kwargs`.

    Note:
        `data_config` itself is not modified. A copy updated with `kwargs` is
        used instead.

    Arguments:
        data_config (dict or list): Data configuration (or list of them).
        **kwargs: Configuration keys that override those in `data_config`.

    Return:
        The object returned by the selected loader.

    Raise:
        DataError: If the specified source type is unknown.
        KeyError: If a required key is missing or the input file extension is
            not recognized.
        FileNotFoundError: If the input file can't be found.
        ValueError: If the requested output format is not available for the input.

    """
    import analysis.data.loaders as _loaders
    # Do we need to merge?
    if isinstance(data_config, list):
        from analysis.data.mergers import merge
        logger.debug("Multiple datasets specified. Merging...")
        return merge([get_data(data) for data in data_config], **kwargs)
    # Work on a copy so the caller's configuration is not consumed by the
    # `pop` calls below (and can therefore be reused).
    data_config = dict(data_config)
    data_config.update(kwargs)
    # Check the configuration
    for key in ('source', 'output-format'):
        if key not in data_config:
            raise KeyError(
                "Bad data configuration -> '{}' key is missing".format(key))
    source_name = data_config.pop('source')
    source_type = data_config.pop('source-type', None)
    if source_type:
        # Only the lookup is guarded, so AttributeErrors raised inside the
        # path getter itself are not misreported as an unknown source type.
        try:
            path_getter = getattr(paths, 'get_{}_path'.format(source_type))
        except AttributeError as error:
            logger.error("Unknown source type -> %s, original error %s",
                         source_type, error)
            raise DataError("Unknown source type -> {}".format(source_type))
        file_name = path_getter(source_name)
    else:
        file_name = source_name
    if not os.path.exists(file_name):
        raise FileNotFoundError(
            "Cannot find input file -> {}".format(file_name))
    tree_name = data_config.pop('tree', '')
    output_format = data_config.pop('output-format').lower()
    # Optional: output-type, cuts, branches
    input_ext = os.path.splitext(file_name)[1]
    input_type = data_config.get('input-type')
    if not input_type:
        try:
            input_type = get_global_var('FILE_TYPES')[input_ext]
        except KeyError:
            raise KeyError(
                "Unknown file extension -> {}. Cannot load file.".format(
                    input_ext))
    get_data_func = getattr(
        _loaders, 'get_{}_from_{}_file'.format(output_format, input_type), None)
    if get_data_func is None:
        raise ValueError("Output format unavailable for input file "
                         "with extension {} -> {}".format(
                             input_ext, output_format))
    logger.debug("Loading data file -> %s:%s", file_name, tree_name)
    return get_data_func(file_name, tree_name, data_config)
예제 #16
0
def run(config_files, link_from):
    """Run the script.

    Fits the efficiency model and makes the plots requested through the
    `plot` key. If the efficiency file already exists, it is loaded instead of
    refitted and only the plots are remade. If the efficiency file and all
    requested plots exist, nothing is done.

    Arguments:
        config_files (list[str]): Path to the configuration files.
        link_from (str): Path to link the results from.

    Raise:
        OSError: If either the configuration files cannot be loaded or some
            of the input files cannot be found.
        KeyError: If some configuration data are missing.
        ValueError: If there is any problem in configuring the efficiency model.
        RuntimeError: If there is a problem during the efficiency fitting.

    """
    try:
        config = _config.load_config(*config_files,
                                     validate=[
                                         'name', 'data/source', 'data/tree',
                                         'parameters', 'model', 'variables'
                                     ])
    except OSError:
        raise OSError(
            "Cannot load configuration files: {}".format(config_files))
    except _config.ConfigError as error:
        # Keys must match the ones passed to `validate` above.
        missing_key_messages = (
            ('name', "No name was specified in the config file!"),
            ('data/source', "No input data specified in the config file!"),
            ('data/tree', "No input data specified in the config file!"),
            ('model', "No efficiency model specified in the config file!"),
            ('parameters',
             "No efficiency model parameters specified in the config file!"),
            ('variables',
             "No efficiency variables to model have been specified in the config file!"),
        )
        for missing_key, message in missing_key_messages:
            if missing_key in error.missing_keys:
                logger.error(message)
        raise KeyError("ConfigError raised -> {}".format(error.missing_keys))
    except KeyError as error:
        logger.error("YAML parsing error -> %s", error)
        raise
    # Do checks and load things
    plot_files = {}
    if config.get('plot', False):
        for var_name in config['variables']:
            plot_files[var_name] = get_efficiency_plot_path(config['name'],
                                                            var=var_name)
    efficiency_class = get_efficiency_model_class(config['model'])
    if not efficiency_class:
        raise ValueError("Unknown efficiency model -> {}".format(
            config['model']))
    # Let's do it
    # pylint: disable=E1101
    efficiency_path = _paths.get_efficiency_path(config['name'])
    if not all(os.path.exists(file_name)
               for file_name in plot_files.values()) or \
            not os.path.exists(efficiency_path):  # If plots don't exist, we load data
        logger.info("Loading data, this may take a while...")
        weight_var = config['data'].get('weight-var-name', None)
        # Prepare data
        config['data']['output-format'] = 'pandas'
        config['data']['variables'] = list(config['variables'])
        if weight_var:
            config['data']['variables'].append(weight_var)
        input_data = get_data(config['data'], **{'output-format': 'pandas'})
        if weight_var:
            logger.info("Data loaded, using %s as weight", weight_var)
        else:
            logger.info("Data loaded, not using any weights")

        if not os.path.exists(efficiency_path):
            logger.info("Fitting efficiency model")
            try:
                eff = efficiency_class.fit(input_data, config['variables'],
                                           weight_var, **config['parameters'])
            except (ValueError, TypeError) as error:
                # Python 3 exceptions have no `.message`; format the error itself.
                raise ValueError(
                    "Cannot configure the efficiency model -> {}".format(error))
            except KeyError as error:
                raise RuntimeError("Missing key -> {}".format(error))
            except Exception as error:
                raise RuntimeError(error)
            output_file = eff.write_to_disk(config['name'], link_from)
            logger.info("Written efficiency file -> %s", output_file)
        else:
            logger.warning(
                "Output efficiency already exists, only redoing plots")
            eff = load_efficiency_model(config['name'])
        if plot_files:
            import seaborn as sns
            sns.set_style("white")
            plt.style.use('file://{}'.format(
                os.path.join(get_global_var('STYLE_PATH'),
                             'matplotlib_LHCb.mplstyle')))
            plots = eff.plot(input_data,
                             weight_var,
                             labels=config.get('plot-labels', {}))
            for var_name, plot in plots.items():
                logger.info("Plotting '%s' efficiency -> %s", var_name,
                            plot_files[var_name])
                plot.savefig(plot_files[var_name], bbox_inches='tight')
    else:
        logger.info("Efficiency file exists: %s. Nothing to do!",
                    efficiency_path)
예제 #17
0
    Return:
        list: Updated PDF paths.

    """
    base_path = get_global_var('BASE_PATH')
    for path in reversed(paths):
        if not os.path.isabs(path):
            path = os.path.abspath(os.path.join(base_path, path))
        if path not in get_global_var('PDF_PATHS'):
            _logger.debug("Adding %s to PDF_PATHS", path)
            get_global_var('PDF_PATHS').insert(0, path)
    return get_global_var('PDF_PATHS')


# Default PDF paths, registered at import time: $ANALYSIS_PATH/pdfs and the
# relative 'pdfs' folder, which add_pdf_paths resolves against BASE_PATH.
# Both are prepended to PDF_PATHS in argument order (unless already present),
# so the ANALYSIS_PATH folder ends up ahead of the BASE_PATH one.
add_pdf_paths(os.path.join(get_global_var('ANALYSIS_PATH'), 'pdfs'),
              'pdfs')


def load_pdf_by_name(name, use_mathmore=False):
    """Load the given PDF using its name.

    It's compiled if needed.

    Arguments:
        name (str): Name of the PDF to load.
        use_mathmore (bool, optional): Load libMathMore before compiling.
            Defaults to False.

    Return:
        `ROOT.RooAbsPdf`: RooFit PDF object.