示例#1
0
def get_output(*args, **kwargs):
    """
    Return an output driver appropriate for the given arguments.

    When a truthy ``output_prefix`` keyword is supplied, the ``Output`` class is
    fetched through ``cobaya.mpi.import_MPI`` (i.e. possibly the MPI-wrapped
    variant) and instantiated; otherwise a dummy ``OutputDummy`` driver is built.
    Every positional and keyword argument is forwarded unchanged to the chosen
    constructor.
    """
    if not kwargs.get("output_prefix"):
        return OutputDummy(*args, **kwargs)
    from cobaya.mpi import import_MPI
    output_class = import_MPI(".output", "Output")
    return output_class(*args, **kwargs)
示例#2
0
def get_Output(*args, **kwargs):
    """
    Build and return the output driver.

    A truthy ``output_prefix`` keyword selects the ``Output`` class obtained via
    ``cobaya.mpi.import_MPI(".output", "Output")``; in any other case the
    ``OutputDummy`` class is used. All arguments are passed through to the
    selected class's constructor.
    """
    wants_output = bool(kwargs.get("output_prefix"))
    if wants_output:
        from cobaya.mpi import import_MPI
        driver_cls = import_MPI(".output", "Output")
    else:
        driver_cls = OutputDummy
    return driver_cls(*args, **kwargs)
示例#3
0
文件: output.py 项目: yufdu/cobaya
def get_output(*args, **kwargs):
    """
    Auxiliary function to retrieve the output driver
    (e.g. whether to get the MPI-wrapped one, or a dummy output driver).

    If the ``prefix`` keyword is truthy, the ``Output`` class obtained through
    ``cobaya.mpi.import_MPI(".output", "Output")`` is instantiated; otherwise an
    ``OutputDummy`` is returned. All remaining arguments are forwarded to the
    chosen constructor.

    The keyword ``output_prefix`` is accepted as a deprecated alias of
    ``prefix``: when given (and not ``None``) it takes precedence over
    ``prefix``. The alias itself is removed before forwarding, so the driver
    only ever receives the new ``prefix`` keyword.
    """
    # MARKED FOR DEPRECATION IN v3.0
    # Pop (rather than copy) the old keyword, so that it is not forwarded to the
    # driver constructor alongside ``prefix``.
    output_prefix = kwargs.pop("output_prefix", None)
    if output_prefix is not None:
        kwargs["prefix"] = output_prefix
    # END OF DEPRECATION BLOCK
    if kwargs.get("prefix"):
        from cobaya.mpi import import_MPI
        return import_MPI(".output", "Output")(*args, **kwargs)
    return OutputDummy(*args, **kwargs)
示例#4
0
def run_script():
    """
    Command-line entry point: parse the arguments, load the input file and run it.

    Command-line arguments:

    * ``input_file``: the YAML input file to run. Its bare file name (without
      directory or extension) may not be ``input`` or ``full``, which are
      reserved.
    * ``-p/--path``: folder where modules were automatically installed (only
      the first value given is used).
    * ``--<_debug>``: flag turning on debug output.

    The modules installation path is resolved as: input file > command line >
    ``COBAYA_MODULES`` environment variable. Specifying it both in the input
    file and in the command line is an error.

    Raises:
        ValueError: if the input file name is a reserved one, or if the modules
            path is given both in the command line and in the input file.
    """
    import os
    import argparse
    from cobaya.conventions import _path_install, _debug
    parser = argparse.ArgumentParser(description="Cobaya's run script.")
    parser.add_argument("input_file",
                        nargs=1,
                        action="store",
                        metavar="input_file.yaml",
                        help="An input file to run.")
    parser.add_argument(
        "-p",
        "--path",
        action="store",
        nargs="+",
        metavar="/some/path",
        help="Path where modules were automatically installed.")
    parser.add_argument("--" + _debug,
                        action="store_true",
                        help="Set this flag for debug output.")
    args = parser.parse_args()
    # Compare the bare file name (no directory, no extension): a leading path such
    # as "chains/input.yaml" must not let a reserved name slip through the check.
    if any(os.path.splitext(os.path.basename(f))[0] in ("input", "full")
           for f in args.input_file):
        raise ValueError("'input' and 'full' are reserved file names. "
                         "Please, use a different one.")
    from cobaya.mpi import import_MPI
    load_input = import_MPI(".input", "load_input")
    info = load_input(args.input_file[0])
    # Solve the modules path: input file > command line > environment variable
    path_env = os.environ.get("COBAYA_MODULES", None)
    # ``--path`` accepts several values, but only the first one is used
    path_cmd = args.path[0] if args.path else None
    path_input = info.get(_path_install)
    if path_cmd and path_input:
        raise ValueError(
            "*CONFLICT* You have specified a modules folder both in the command line "
            "('%s') and the input file ('%s'). There should only be one." %
            (path_cmd, path_input))
    info[_path_install] = path_input or path_cmd or path_env
    if getattr(args, _debug):
        info[_debug] = True
    run(info)
示例#5
0
def run_script():
    """
    Command-line entry point of Cobaya's run script.

    Parses the command-line arguments, loads the input and then calls ``post``
    if the input contains a ``_post`` block, or ``run`` otherwise.

    The positional argument may be either a YAML input file (recognised by one of
    the extensions in ``_yaml_extensions``) or the output prefix of an existing
    run. In the latter case the corresponding ``updated`` info file is loaded,
    the output prefix is kept where the run is, and resuming is switched on.

    Values given in the command line take precedence over those in the input:
    the output prefix, the packages path and debug always; resume/force only
    when either of those flags is given.

    Raises:
        ValueError: if the input file's bare name is a reserved one
            (``input``/``updated``), or if the positional argument is neither a
            valid input file nor the prefix of an existing run.
    """
    warn_deprecation()
    import os
    import argparse
    parser = argparse.ArgumentParser(description="Cobaya's run script.")
    parser.add_argument("input_file",
                        nargs=1,
                        action="store",
                        metavar="input_file.yaml",
                        help="An input file to run.")
    parser.add_argument("-" + _packages_path_arg[0],
                        "--" + _packages_path_arg_posix,
                        action="store",
                        nargs=1,
                        metavar="/packages/path",
                        default=[None],
                        help="Path where external packages were installed.")
    # MARKED FOR DEPRECATION IN v3.0
    modules = "modules"
    parser.add_argument("-" + modules[0],
                        "--" + modules,
                        action="store",
                        nargs=1,
                        required=False,
                        metavar="/packages/path",
                        default=[None],
                        help="To be deprecated! "
                        "Alias for %s, which should be used instead." %
                        _packages_path_arg_posix)
    # END OF DEPRECATION BLOCK -- CONTINUES BELOW!
    parser.add_argument("-" + _output_prefix[0],
                        "--" + _output_prefix,
                        action="store",
                        nargs=1,
                        metavar="/some/path",
                        default=[None],
                        help="Path and prefix for the text output.")
    parser.add_argument("-" + _debug[0],
                        "--" + _debug,
                        action="store_true",
                        help="Produce verbose debug output.")
    continuation = parser.add_mutually_exclusive_group(required=False)
    continuation.add_argument(
        "-" + _resume[0],
        "--" + _resume,
        action="store_true",
        help="Resume an existing chain if it has similar info "
        "(fails otherwise).")
    continuation.add_argument("-" + _force[0],
                              "--" + _force,
                              action="store_true",
                              help="Overwrites previous output, if it exists "
                              "(use with care!)")
    parser.add_argument("--%s" % _test_run,
                        action="store_true",
                        help="Initialize model and sampler, and exit.")
    parser.add_argument("--version", action="version", version=__version__)
    parser.add_argument("--no-mpi",
                        action='store_true',
                        help="disable MPI when mpi4py installed but MPI does "
                        "not actually work")
    arguments = parser.parse_args()
    # MPI must be disabled before anything is imported through ``import_MPI``
    if arguments.no_mpi or getattr(arguments, _test_run, False):
        set_mpi_disabled()
    # Compare the bare file name (no directory, no extension): a leading path such
    # as "chains/input.yaml" must not let a reserved name slip through the check.
    if any(os.path.splitext(os.path.basename(f))[0] in ("input", "updated")
           for f in arguments.input_file):
        raise ValueError("'input' and 'updated' are reserved file names. "
                         "Please, use a different one.")
    load_input = import_MPI(".input", "load_input")
    given_input = arguments.input_file[0]
    if any(given_input.lower().endswith(ext) for ext in _yaml_extensions):
        info = load_input(given_input)
        output_prefix_cmd = getattr(arguments, _output_prefix)[0]
        output_prefix_input = info.get(_output_prefix)
        info[_output_prefix] = output_prefix_cmd or output_prefix_input
    else:
        # Passed an existing output_prefix? Try to find the corresponding *.updated.yaml
        updated_file = get_info_path(*split_prefix(given_input),
                                     kind="updated")
        try:
            info = load_input(updated_file)
        except IOError as excpt:
            raise ValueError(
                "Not a valid input file, or non-existent run to resume") from excpt
        # We need to update the output_prefix to resume the run *where it is*
        info[_output_prefix] = given_input
        # If input given this way, we obviously want to resume!
        info[_resume] = True
    # Solve packages installation path: cmd > input
    # (no environment variable is consulted at this point)
    # MARKED FOR DEPRECATION IN v3.0
    if getattr(arguments, modules) != [None]:
        logger_setup()
        logger = logging.getLogger(__name__.split(".")[-1])
        logger.warning(
            "*DEPRECATION*: -m/--modules will be deprecated in favor of "
            "-%s/--%s in the next version. Please, use that one instead.",
            _packages_path_arg[0], _packages_path_arg_posix)
        # BEHAVIOUR TO BE REPLACED BY ERROR:
        if getattr(arguments, _packages_path_arg) == [None]:
            setattr(arguments, _packages_path_arg, getattr(arguments, modules))
    # BEHAVIOUR TO BE REPLACED BY ERROR:
    check_deprecated_modules_path(info)
    # END OF DEPRECATION BLOCK
    info[_packages_path] = \
        getattr(arguments, _packages_path_arg)[0] or info.get(_packages_path)
    info[_debug] = getattr(arguments, _debug) or info.get(
        _debug, _debug_default)
    info[_test_run] = getattr(arguments, _test_run, False)
    # If any of resume|force given as cmd args, ignore those in the input file
    resume_arg, force_arg = [
        getattr(arguments, arg) for arg in [_resume, _force]
    ]
    if any([resume_arg, force_arg]):
        info[_resume], info[_force] = resume_arg, force_arg
    if _post in info:
        post(info)
    else:
        run(info)
示例#6
0
def run_script():
    """
    Command-line entry point of Cobaya's run script.

    Parses the command-line arguments, loads the input and then calls ``post``
    if the input contains a ``_post`` block, or ``run`` otherwise.

    The positional argument may be either a YAML input file (recognised by one of
    the extensions in ``_yaml_extensions``) or the output prefix of an existing
    sample. In the latter case the corresponding ``updated`` YAML file is loaded,
    the output prefix is kept where the sample is, and resuming is switched on.

    The modules installation path is resolved as: command line > environment
    variable ``_modules_path_env`` > input. Debug given in the command line
    overrides the input. Resume/force given in the command line override the
    input only when either flag is actually passed; otherwise the input's values
    (or the defaults) are kept.

    Raises:
        ValueError: if the input file's bare name is a reserved one
            (``input``/``updated``), or if the positional argument is neither a
            valid input file nor the prefix of an existing sample.
    """
    warn_deprecation()
    import os
    import argparse
    parser = argparse.ArgumentParser(description="Cobaya's run script.")
    parser.add_argument("input_file",
                        nargs=1,
                        action="store",
                        metavar="input_file.yaml",
                        help="An input file to run.")
    parser.add_argument(
        "-" + _modules_path_arg[0],
        "--" + _modules_path_arg,
        action="store",
        nargs=1,
        metavar="/some/path",
        default=[None],
        help="Path where modules were automatically installed.")
    parser.add_argument("-" + _output_prefix[0],
                        "--" + _output_prefix,
                        action="store",
                        nargs=1,
                        metavar="/some/path",
                        default=[None],
                        help="Path and prefix for the text output.")
    parser.add_argument("-" + _debug[0],
                        "--" + _debug,
                        action="store_true",
                        help="Produce verbose debug output.")
    continuation = parser.add_mutually_exclusive_group(required=False)
    continuation.add_argument(
        "-" + _resume[0],
        "--" + _resume,
        action="store_true",
        help="Resume an existing chain if it has similar info "
        "(fails otherwise).")
    continuation.add_argument("-" + _force[0],
                              "--" + _force,
                              action="store_true",
                              help="Overwrites previous output, if it exists "
                              "(use with care!)")
    parser.add_argument("--version", action="version", version=__version__)
    args = parser.parse_args()
    # Compare the bare file name (no directory, no extension): a leading path such
    # as "chains/input.yaml" must not let a reserved name slip through the check.
    if any(os.path.splitext(os.path.basename(f))[0] in ("input", "updated")
           for f in args.input_file):
        raise ValueError("'input' and 'updated' are reserved file names. "
                         "Please, use a different one.")
    load_input = import_MPI(".input", "load_input")
    given_input = args.input_file[0]
    if any(given_input.lower().endswith(ext) for ext in _yaml_extensions):
        info = load_input(given_input)
        output_prefix_cmd = getattr(args, _output_prefix)[0]
        output_prefix_input = info.get(_output_prefix)
        info[_output_prefix] = output_prefix_cmd or output_prefix_input
    else:
        # Passed an existing output_prefix? Try to find the corresponding *.updated.yaml
        updated_file = (
            given_input +
            (_separator_files if not given_input.endswith(os.sep) else "") +
            _updated_suffix + _yaml_extensions[0])
        try:
            info = load_input(updated_file)
        except IOError as excpt:
            raise ValueError(
                "Not a valid input file, or non-existent sample to resume") from excpt
        # We need to update the output_prefix to resume the sample *where it is*
        info[_output_prefix] = given_input
        # If input given this way, we obviously want to resume!
        info[_resume] = True
    # solve modules installation path cmd > env > input
    path_cmd = getattr(args, _modules_path_arg)[0]
    path_env = os.environ.get(_modules_path_env, None)
    path_input = info.get(_path_install)
    info[_path_install] = path_cmd or (path_env or path_input)
    info[_debug] = getattr(args, _debug) or info.get(_debug, _debug_default)
    # ``store_true`` flags are always present in ``args`` (False when not given),
    # so only let them override ``info`` when one was actually passed. Otherwise
    # keep the input's values -- including the ``resume = True`` set above when an
    # existing output prefix was given -- falling back to the defaults.
    resume_arg, force_arg = getattr(args, _resume), getattr(args, _force)
    if resume_arg or force_arg:
        info[_resume], info[_force] = resume_arg, force_arg
    else:
        info.setdefault(_resume, _resume_default)
        info.setdefault(_force, False)
    if _post in info:
        post(info)
    else:
        run(info)
示例#7
0
def run(info):
    """
    Run the sampler described by ``info``.

    Sets up logging, the output driver (MPI-aware ``Output`` when an output
    prefix is given, ``Output_dummy`` otherwise), the full info with defaults,
    and then the parametrization, prior, likelihoods and sampler, and finally
    runs the sampler.

    Parameters:
        info: dictionary with the information needed for the run. An input file
            must be loaded first with ``cobaya.input.load_input``.

    Returns:
        A tuple ``(full_info, products)``: a deep copy of the full input
        information as produced by ``cobaya.input.get_full_info``, and the
        result of the sampler's ``products()``.

    Raises:
        AssertionError: if ``info`` has no ``items`` attribute (i.e. is not
            dict-like). Note that this check is skipped under ``python -O``.
    """

    assert hasattr(info, "items"), (
        "The argument of `run` must be a dictionary with the info needed for the run. "
        "If you were trying to pass an input file instead, load it first with "
        "`cobaya.input.load_input`.")

    # Import names
    from cobaya.conventions import _likelihood, _prior, _params
    from cobaya.conventions import _theory, _sampler, _path_install
    from cobaya.conventions import _debug, _debug_file, _output_prefix

    # Configure the logger ASAP
    from cobaya.log import logger_setup
    logger_setup(info.get(_debug), info.get(_debug_file))

    # Debug (lazy call)
    import logging
    if logging.root.getEffectiveLevel() <= logging.DEBUG:
        # Don't dump unless we are doing output, just in case something not serializable
        # May be fixed in the future if we find a way to serialize external functions
        if info.get(_output_prefix):
            from cobaya.yaml import yaml_dump
            logging.getLogger(__name__.split(".")[-1]).debug(
                "Input info (dumped to YAML):\n%s", yaml_dump(info))

    # Import general classes
    from cobaya.prior import Prior
    from cobaya.sampler import get_Sampler as Sampler

    # Import the functions and classes that need MPI wrapping
    from cobaya.mpi import import_MPI
    from cobaya.likelihood import LikelihoodCollection as Likelihood

    # Initialise output, if required
    do_output = info.get(_output_prefix)
    if do_output:
        Output = import_MPI(".output", "Output")
        output = Output(info)
    else:
        from cobaya.output import Output_dummy
        output = Output_dummy(info)

    # Create the full input information, including defaults for each module.
    from cobaya.input import get_full_info
    full_info = get_full_info(info)
    if logging.root.getEffectiveLevel() <= logging.DEBUG:
        # Don't dump unless we are doing output, just in case something not serializable
        # May be fixed in the future if we find a way to serialize external functions
        if info.get(_output_prefix):
            # Import here as well: the earlier import is conditional, so
            # ``yaml_dump`` is not guaranteed to be bound at this point.
            from cobaya.yaml import yaml_dump
            logging.getLogger(__name__.split(".")[-1]).debug(
                "Updated info (dumped to YAML):\n%s", yaml_dump(full_info))
    # We dump the info now, before modules initialization, lest it is accidentally modified
    output.dump_info(info, full_info)

    # Set the path of the installed modules, if given
    from cobaya.tools import set_path_to_installation
    set_path_to_installation(info.get(_path_install))

    # Initialise parametrization, likelihoods and prior
    from cobaya.parametrization import Parametrization
    with Parametrization(full_info[_params]) as par:
        with Prior(par, full_info.get(_prior)) as prior:
            with Likelihood(full_info[_likelihood], par,
                            full_info.get(_theory)) as lik:
                with Sampler(full_info[_sampler], par, prior, lik,
                             output) as sampler:
                    sampler.run()

    # For scripted calls
    return deepcopy(full_info), sampler.products()