Example #1
 def __call__(self, *args, warn=True, **kwargs):
     if warn:
         get_logger(self.__class__.__name__).warning(
             "Do not call the instance directly. Use instead methods P(z, k) or "
             "logP(z, k) to get the (log)power spectrum. (If you know what you are "
             "doing, pass warn=False)")
     return super().__call__(*args, **kwargs)
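A minimal self-contained sketch of the same warn-on-direct-call pattern, assuming (purely for illustration) a SciPy RectBivariateSpline base class instead of cobaya's interpolator:

import numpy as np
from scipy.interpolate import RectBivariateSpline

class WarningInterpolator(RectBivariateSpline):  # hypothetical name
    def __call__(self, *args, warn=True, **kwargs):
        if warn:
            print("Do not call the instance directly; pass warn=False to silence.")
        return super().__call__(*args, **kwargs)

z, k = np.linspace(0, 2, 8), np.linspace(0.001, 1, 16)
interp = WarningInterpolator(z, k, np.ones((8, 16)))
interp(0.5, 0.1)              # emits the warning, then evaluates
interp(0.5, 0.1, warn=False)  # silent, for callers that know what they are doing
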
Example #2
 def check_ranges(self, z, k):
     """Checks that we are not trying to extrapolate beyond the interpolator limits."""
     z = np.atleast_1d(z).flatten()
     min_z, max_z = min(z), max(z)
     if min_z < self.zmin and not np.allclose(min_z, self.zmin):
         raise LoggedError(
             get_logger(self.__class__.__name__),
             f"Not possible to extrapolate to z={min(z)} "
             f"(minimum z computed is {self.zmin}).")
     if max_z > self.zmax and not np.allclose(max_z, self.zmax):
         raise LoggedError(
             get_logger(self.__class__.__name__),
             f"Not possible to extrapolate to z={max(z)} "
             f"(maximum z computed is {self.zmax}).")
     k = np.atleast_1d(k).flatten()
     min_k, max_k = min(k), max(k)
     if min_k < self.kmin and not np.allclose(min_k, self.kmin):
         raise LoggedError(
             get_logger(self.__class__.__name__),
             f"Not possible to extrapolate to k={min(k)} 1/Mpc "
             f"(minimum k possible is {self.kmin} 1/Mpc).")
     if max_k > self.kmax and not np.allclose(max_k, self.kmax):
         raise LoggedError(
             get_logger(self.__class__.__name__),
             f"Not possible to extrapolate to k={max(k)} 1/Mpc "
             f"(maximum k possible is {self.kmax} 1/Mpc).")
Example #3
 def install(cls, path=None, data=True, no_progress_bars=False, **_kwargs):
     if not data:
         return True
     log = get_logger(cls.get_qualified_class_name())
     opts = cls.get_install_options()
     if not opts:
         log.info("No install options. Nothing to do.")
         return True
     repo = opts.get("github_repository", None)
     if repo:
         from cobaya.install import download_github_release
         log.info("Downloading %s data..." % repo)
         return download_github_release(
             os.path.join(path, "data"), repo, opts.get("github_release", "master"),
             no_progress_bars=no_progress_bars, logger=log)
     else:
         full_path = cls.get_path(path)
         if not os.path.exists(full_path):
             os.makedirs(full_path)
         url = opts["download_url"]
         log.info("Downloading likelihood data file: %s...", url)
         from cobaya.install import download_file
         if not download_file(url, full_path, decompress=True, logger=log,
                              no_progress_bars=no_progress_bars):
             return False
         log.info("Likelihood data downloaded and uncompressed correctly.")
         return True
Example #4
 def install(cls, path=None, force=False, code=False, data=False,
             no_progress_bars=False):
     if not code:
         return True
     log = get_logger(__name__)
     log.info("Downloading PolyChord...")
     success = download_github_release(os.path.join(path, "code"), cls._pc_repo_name,
                                       cls._pc_repo_version,
                                       no_progress_bars=no_progress_bars,
                                       logger=log)
     if not success:
         log.error("Could not download PolyChord.")
         return False
     log.info("Compiling (Py)PolyChord...")
     from subprocess import Popen, PIPE
     # Need to redefine PWD in the environment, because the Makefile
     # reads it and is not affected by Popen's cwd
     cwd = os.path.join(path, "code",
                        cls._pc_repo_name[cls._pc_repo_name.find("/") + 1:])
     my_env = os.environ.copy()
     my_env.update({"PWD": cwd})
     if "CC" not in my_env:
         my_env["CC"] = "mpicc"
     if "CXX" not in my_env:
         my_env["CXX"] = "mpicxx"
     process_make = Popen([sys.executable, "setup.py", "build"],
                          cwd=cwd, env=my_env, stdout=PIPE, stderr=PIPE)
     out, err = process_make.communicate()
     if process_make.returncode:
         log.info(out.decode("utf-8"))
         log.info(err.decode("utf-8"))
         log.error("Python build failed!")
         return False
     return True
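The PWD override matters because Makefiles can consult that variable instead of the real working directory. A self-contained sketch of running a subprocess with an adjusted environment (the build path is hypothetical):

import os
import sys
from subprocess import Popen, PIPE

env = os.environ.copy()
env.setdefault("CC", "mpicc")  # same default compiler choice as above
env["PWD"] = "/tmp/build"      # hypothetical build directory
proc = Popen([sys.executable, "-c", "import os; print(os.environ['PWD'])"],
             env=env, stdout=PIPE, stderr=PIPE)
out, err = proc.communicate()
print(out.decode().strip())  # prints /tmp/build
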
Example #5
 def set_lock(self, log, filename, force=False):
     if self.has_lock():
         return
     self.lock_file = filename + '.locked'
     self.lock_error_file = filename + '.lock_err'
     try:
         os.remove(self.lock_error_file)
     except OSError:
         pass
     self.log = log or get_logger("file_lock")
     try:
         h: Any = None
         if use_portalocker():
             import portalocker
             try:
                 h = open(self.lock_file, 'wb')
                 portalocker.lock(h, portalocker.LOCK_EX + portalocker.LOCK_NB)
                 self._file_handle = h
             except portalocker.exceptions.BaseLockException:
                 if h:
                     h.close()
                 self.lock_error()
         else:
             # will work, but crashes will leave .lock files that will raise error
             self._file_handle = open(self.lock_file, 'wb' if force else 'xb')
     except OSError:
         self.lock_error()
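A self-contained sketch of the non-blocking exclusive lock used above, with the same fallback to exclusive-create mode when portalocker is unavailable:

def try_lock(lock_file):
    """Return an open handle holding the lock, or None if already taken."""
    try:
        import portalocker
        h = open(lock_file, 'wb')
        try:
            portalocker.lock(h, portalocker.LOCK_EX | portalocker.LOCK_NB)
            return h
        except portalocker.exceptions.BaseLockException:
            h.close()
            return None
    except ImportError:
        try:
            # works, but a crash can leave a stale lock file behind
            return open(lock_file, 'xb')
        except OSError:
            return None
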
Example #6
def bib_script(args=None):
    """Command line script for the bibliography."""
    warn_deprecation()
    # Parse arguments and launch
    import argparse
    parser = argparse.ArgumentParser(
        prog="cobaya bib",
        description=("Prints bibliography to be cited for one or more "
                     "components or input files."))
    parser.add_argument(
        "files_or_components",
        action="store",
        nargs="+",
        metavar="input_file.yaml|component_name",
        help="Component(s) or input file(s) whose bib info is requested.")
    arguments = parser.parse_args(args)
    # Configure the logger ASAP
    logger_setup()
    logger = get_logger("bib")
    # Gather requests
    infos: List[Union[Dict, str]] = []
    for f in arguments.files_or_components:
        if os.path.splitext(f)[1].lower() in Extension.yamls:
            infos += [load_input(f)]
        else:  # a single component name, no kind specified
            infos += [f]
    if not infos:
        logger.info(
            "Nothing to do. Pass input files or component names as arguments.")
        return
    print(pretty_repr_bib(*get_bib_info(*infos, logger=logger)))
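Since argparse accepts an explicit argument list, the script can also be driven programmatically (file and component names here are illustrative):

bib_script(["my_input.yaml"])   # same effect as: cobaya bib my_input.yaml
bib_script(["mcmc", "classy"])  # bare component names, no kind specified
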
Example #7
 def is_installed(cls, **kwargs):
     """
     Performs an installation check and returns ``True`` if successful, ``False`` if
     not, or raises :class:`tools.VersionCheckError` if there is an obsolete
     installation.
     """
     if kwargs.get("data", True):
         path = kwargs["path"]
         path = cls.get_path(path)  # ensure full install path passed
         opts = cls.get_install_options()
         if not opts:
             return True
         elif not (os.path.exists(path) and len(os.listdir(path)) > 0):
             log = get_logger(cls.get_qualified_class_name())
             log.error("The given installation path does not exist: '%s'",
                       path)
             return False
         elif opts.get("github_release"):
             try:
                 with open(os.path.join(path, _version_filename), "r") as f:
                     installed_version = version.parse(f.readlines()[0])
             except FileNotFoundError:  # old install: no version file
                 raise VersionCheckError("Could not read current version.")
             min_version = version.parse(opts.get("github_release"))
             if installed_version < min_version:
                 raise VersionCheckError(
                     f"Installed version ({installed_version}) "
                     f"older than minimum required one ({min_version}).")
     return True
Example #8
def download_github_release(directory,
                            repo_name,
                            release_name,
                            repo_rename=None,
                            no_progress_bars=False,
                            logger=None):
    logger = logger or get_logger("install")
    if "/" in repo_name:
        github_user = repo_name[:repo_name.find("/")]
        repo_name = repo_name[repo_name.find("/") + 1:]
    else:
        github_user = "******"
    if not os.path.exists(directory):
        os.makedirs(directory)
    url = (r"https://github.com/" + github_user + "/" + repo_name +
           "/archive/" + release_name + ".tar.gz")
    if not download_file(url,
                         directory,
                         decompress=True,
                         no_progress_bars=no_progress_bars,
                         logger=logger):
        return False
    # Remove version number from directory name
    w_version = next(d for d in os.listdir(directory)
                     if (d.startswith(repo_name) and len(d) != len(repo_name)))
    repo_rename = repo_rename or repo_name
    repo_path = os.path.join(directory, repo_rename)
    if os.path.exists(repo_path):
        shutil.rmtree(repo_path)
    os.rename(os.path.join(directory, w_version), repo_path)
    logger.info("%s %s downloaded and decompressed correctly.", repo_name,
                release_name)
    return True
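A hedged usage sketch of the helper above (repository and release names are hypothetical):

log = get_logger("install")
ok = download_github_release("/tmp/code", "some_user/some_repo", "v1.0",
                             no_progress_bars=True, logger=log)
print("installed" if ok else "failed")
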
Example #9
def install_clik(path, no_progress_bars=False):
    log = get_logger("clik")
    log.info("Installing pre-requisites...")
    for req in ("cython", "astropy"):
        exit_status = pip_install(req)
        if exit_status:
            raise LoggedError(log, "Failed installing '%s'.", req)
    log.info("Downloading...")
    clik_url = pla_url_prefix + '152000'
    if not download_file(clik_url,
                         path,
                         decompress=True,
                         no_progress_bars=no_progress_bars,
                         logger=log):
        log.error("Not possible to download clik.")
        return False
    source_dir = get_clik_source_folder(path)
    log.info('Installing from directory %s' % source_dir)
    cwd = os.getcwd()
    try:
        os.chdir(source_dir)
        log.info("Configuring... (and maybe installing dependencies...)")
        flags = ["--install_all_deps", "--extra_lib=m"
                 ]  # missing for some reason in some systems, but harmless
        if not execute([sys.executable, "waf", "configure"] + flags):
            log.error("Configuration failed!")
            return False
        log.info("Compiling...")
        if not execute([sys.executable, "waf", "install"]):
            log.error("Compilation failed!")
            return False
    finally:
        os.chdir(cwd)
    log.info("Finished!")
    return True
Example #10
def download_file(url,
                  path,
                  no_progress_bars=False,
                  decompress=False,
                  logger=None):
    logger = logger or get_logger("install")
    with tempfile.TemporaryDirectory() as tmp_path:
        try:
            req = requests.get(url, allow_redirects=True, stream=True)
            # get hinted filename if available:
            try:
                filename = re.findall("filename=(.+)",
                                      req.headers['content-disposition'])[0]
                filename = filename.strip('"\'')
            except KeyError:
                filename = os.path.basename(url)
            filename_tmp_path = os.path.normpath(
                os.path.join(tmp_path, filename))
            size = int(req.headers.get('content-length', 0))
            # Adapted from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
            if not no_progress_bars:
                bar = tqdm.tqdm(total=size,
                                unit='iB',
                                unit_scale=True,
                                unit_divisor=1024)
            with open(filename_tmp_path, 'wb') as f:
                for data in req.iter_content(chunk_size=1024):
                    chunk_size = f.write(data)
                    if not no_progress_bars:
                        bar.update(chunk_size)
            if not no_progress_bars:
                bar.close()
            logger.info('Downloaded filename %s', filename)
        except Exception as e:
            logger.error("Error downloading %s' to folder '%s': %s", url,
                         tmp_path, e)
            return False
        logger.debug('Got: %s', filename)
        if not decompress:
            return True
        extension = os.path.splitext(filename)[-1][1:]
        try:
            if extension == "zip":
                from zipfile import ZipFile
                with ZipFile(filename_tmp_path, 'r') as zipObj:
                    zipObj.extractall(path)
            else:
                import tarfile
                if extension == "tgz":
                    extension = "gz"
                with tarfile.open(filename_tmp_path, "r:" + extension) as tar:
                    tar.extractall(path)
            logger.debug('Decompressed: %s', filename)
            return True
        except Exception as excpt:
            logger.error(
                "Error decompressing downloaded file! Corrupt file? [%s]",
                excpt)
            return False
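A usage sketch of download_file (URL and target folder are illustrative); the archive is fetched to a temporary directory and decompressed into the given path:

if download_file("https://example.org/data.tar.gz", "/tmp/data",
                 decompress=True, no_progress_bars=True):
    print("data ready under /tmp/data")
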
Example #11
def is_installed_clik(path, reload=False):
    # min_version here is checked inside get_clik_import_path, since it is displayed
    # in the folder name and cannot be retrieved from the module.
    try:
        return bool(load_clik(
            "clik", path=path, get_import_path=get_clik_import_path,
            reload=reload, logger=get_logger("clik"), not_installed_level="debug"))
    except ComponentNotInstalledError:
        return False
Example #12
 def __exit__(self, exc_type, exc_val, exc_tb):
     if log:
         log.info('END %s %s', self.name, self.tag)
     if exc_type:
         self.set(State.ERROR)
         if not self.wait_all_ended(
                 timeout=not issubclass(exc_type, OtherProcessError)):
             from cobaya.log import get_traceback_text, LoggedError, get_logger
             get_logger(self.name).critical(
                 "Aborting MPI due to error" if issubclass(exc_type, LoggedError) else
                 get_traceback_text(sys.exc_info()))
             self.timeout_abort_proc()
             self.wait_all_ended()  # if didn't actually MPI abort
     else:
         self.set(State.END)
         self.wait_all_ended()
     set_current_process_state(self.last_process_state)
     if not exc_type and any(self.states == State.ERROR):
         self.fire_error()
Example #13
 def is_installed(cls, **kwargs):
     log = get_logger(cls.__name__)
     if not kwargs.get("code", True):
         return True
     check = kwargs.get("check", True)
     func = log.info if check else log.error
     path: Optional[str] = kwargs["path"]
     if path is not None and path.lower() == "global":
         path = None
     if path and not kwargs.get("allow_global"):
         if is_main_process():
             log.info("Importing *local* PolyChord from '%s'.", path)
         if not os.path.exists(path):
             if is_main_process():
                 func("The given folder does not exist: '%s'", path)
             return False
         poly_build_path = cls.get_import_path(path)
         if not poly_build_path:
             return False
     elif not path:
         if is_main_process():
             log.info("Importing *global* PolyChord.")
         poly_build_path = None
     else:
         if is_main_process():
             log.info(
                 "Importing *auto-installed* PolyChord (but defaulting to *global*)."
             )
         poly_build_path = cls.get_import_path(path)
     cls._poly_build_path = poly_build_path
     try:
         # TODO: add min_version when polychord module version available
         return load_module('pypolychord',
                            path=poly_build_path,
                            min_version=None)
     except ModuleNotFoundError:
         if path is not None and path.lower() != "global":
             log.error(
                 "Couldn't find the PolyChord python interface at '%s'. "
                 "Are you sure it has been installed there?", path)
         elif not check:
             log.error(
                 "Could not import global PolyChord installation. "
                 "Specify a Cobaya or PolyChord installation path, "
                 "or install the PolyChord Python interface globally with "
                 "'cd /path/to/polychord/ ; python setup.py install'")
         return False
     except ImportError as e:
         log.error(
             "Couldn't load the PolyChord python interface in %s:\n"
             "%s", poly_build_path or "global", e)
         return False
     except VersionCheckError as e:
         log.error(str(e))
         return False
Example #14
 def is_installed(cls, reload=False, **kwargs):
     if not kwargs.get("code", True):
         return True
     try:
         return bool(load_external_module(
             "pypolychord", path=kwargs["path"],
             get_import_path=get_compiled_import_path,
             min_version=cls._pc_repo_version, reload=reload,
             logger=get_logger(cls.__name__), not_installed_level="debug"))
     except ComponentNotInstalledError:
         return False
Example #15
 def is_installed(cls, **kwargs):
     if not kwargs.get("code", True):
         return True
     log = get_logger(cls.__name__)
     import platform
     check = kwargs.get("check", True)
     func = log.info if check else log.error
     path = kwargs["path"]
     if path is not None and path.lower() == "global":
         path = None
     if isinstance(path, str) and not kwargs.get("allow_global"):
         log.info("Importing *local* CAMB from " + path)
         if not os.path.exists(path):
             func("The given folder does not exist: '%s'", path)
             return False
         if not os.path.exists(os.path.join(path, "setup.py")):
             func(
                 "Either CAMB is not in the given folder, '%s', or you are using"
                 " a very old version without the Python interface.", path)
             return False
         lib_name = ("cambdll.dll" if platform.system() == "Windows"
                     else "camblib.so")
         if not os.path.isfile(
                 os.path.realpath(os.path.join(path, "camb", lib_name))):
             log.error(
                 "CAMB installation at '%s' appears not to be compiled.",
                 path)
             return False
     elif not path:
         log.info("Importing *global* CAMB.")
         path = None
     else:
         log.info(
             "Importing *auto-installed* CAMB (but defaulting to *global*)."
         )
     try:
         return load_module("camb",
                            path=path,
                            min_version=cls._min_camb_version)
     except ImportError:
         if path is not None and path.lower() != "global":
             func(
                 "Couldn't find the CAMB python interface at '%s'. "
                 "Are you sure it has been installed there?", path)
         elif not check:
             log.error("Could not import global CAMB installation. "
                       "Specify a Cobaya or CAMB installation path, "
                       "or install the 'camb' Python package globally.")
         return False
     except VersionCheckError as e:
         log.error(str(e))
         return False
Example #16
 def install(cls, path=None, force=False, code=True, data=True,
             no_progress_bars=False):
     name = cls.get_qualified_class_name()
     log = get_logger(name)
     path_names = {"code": common_path, "data": get_data_path(name)}
     import platform
     if platform.system() == "Windows":
         log.error("Not compatible with Windows.")
         return False
     global _clik_install_failed
     if _clik_install_failed:
         log.info("Previous clik install failed, skipping")
         return False
     # Create common folders: all planck likelihoods share install
     # folder for code and data
     paths = {}
     for s in ("code", "data"):
         if eval(s):  # i.e., the value of the local 'code' or 'data' flag
             paths[s] = os.path.realpath(os.path.join(path, s, path_names[s]))
             if not os.path.exists(paths[s]):
                 os.makedirs(paths[s])
     success = True
     # Install clik
     if code and (not is_installed_clik(paths["code"]) or force):
         log.info("Installing the clik code.")
         success *= install_clik(paths["code"], no_progress_bars=no_progress_bars)
         if not success:
             log.warning("clik code installation failed! "
                         "Try configuring+compiling by hand at " + paths["code"])
             _clik_install_failed = True
     if data:
         # 2nd test, in case the code wasn't there but the data is:
         if force or not cls.is_installed(path=path, code=False, data=True):
             log.info("Downloading the likelihood data.")
             product_id, _ = get_product_id_and_clik_file(name)
             # Download and decompress the particular likelihood
             url = pla_url_prefix + product_id
             # Helper for the progress bars: some known product download sizes
             # (no actual effect if missing or wrong!)
             size = {"1900": 314153370, "1903": 4509715660, "151902": 60293120,
                     "151905": 5476083302, "151903": 8160437862}.get(product_id)
             if not download_file(url, paths["data"], size=size, decompress=True,
                                  logger=log, no_progress_bars=no_progress_bars):
                 log.error("Not possible to download this likelihood.")
                 success = False
             # Additional data and covmats, stored in same repo as the
             # 2018 python lensing likelihood
             from cobaya.likelihoods.planck_2018_lensing import native
             if not native.is_installed(data=True, path=path):
                 success *= native.install(path=path, force=force, code=code,
                                           data=data,
                                           no_progress_bars=no_progress_bars)
     return success
Example #17
 def is_installed(cls, **kwargs):
     if kwargs.get("data", True):
         path = kwargs["path"]
         opts = cls.get_install_options()
         if not opts:
             return True
         elif not (os.path.exists(path) and len(os.listdir(path)) > 0):
             log = get_logger(cls.get_qualified_class_name())
             func = log.info if kwargs.get("check", True) else log.error
             func("The given installation path does not exist: '%s'", path)
             return False
     return True
Example #18
def get_bib_info(*infos, logger=None):
    """
    Gathers and returns the descriptions and bibliographic sources for the components
    mentioned in ``infos``.

    ``infos`` can be input dictionaries or single component names.
    """
    if not logger:
        logger_setup()
        logger = get_logger("bib")
    used_components, component_infos = get_used_components(*infos,
                                                           return_infos=True)
    descs: InfoDict = {}
    bibs: InfoDict = {}
    for kind, components in used_components.items():
        if kind is None:
            continue  # we will deal with bare component names later, to avoid repetition
        descs[kind], bibs[kind] = {}, {}
        for component in components:
            try:
                descs[kind][component] = get_desc_component(
                    component, kind, component_infos[component])
                bibs[kind][component] = get_bib_component(component, kind)
            except ComponentNotFoundError:
                sugg = similar_internal_class_names(component)
                logger.error(
                    f"Could not identify component '{component}'. "
                    f"Did you mean any of the following? {sugg} (mind capitalization!)"
                )
                continue
    # Deal with bare component names
    for component in used_components.get(None, []):
        try:
            cls = get_component_class(component)
        except ComponentNotFoundError:
            sugg = similar_internal_class_names(component)
            logger.error(
                f"Could not identify component '{component}'. "
                f"Did you mean any of the following? {sugg} (mind capitalization!)"
            )
            continue
        kind = cls.get_kind()
        if kind not in descs:
            descs[kind], bibs[kind] = {}, {}
        if kind in descs and component in descs[kind]:
            continue  # avoid repetition
        descs[kind][component] = get_desc_component(cls, kind)
        bibs[kind][component] = get_bib_component(cls, kind)
    descs["cobaya"] = {"cobaya": cobaya_desc}
    bibs["cobaya"] = {"cobaya": cobaya_bib}
    return descs, bibs
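The function accepts parsed input dictionaries, bare component names, or a mix (names here are illustrative):

descs, bibs = get_bib_info({"sampler": {"mcmc": None}}, "classy")
print(descs["cobaya"]["cobaya"])  # description entry for cobaya itself
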
Example #19
 def install(cls, path=None, code=True, no_progress_bars=False, **_kwargs):
     log = get_logger(cls.__name__)
     if not code:
         log.info("Code not requested. Nothing to do.")
         return True
     log.info("Installing pre-requisites...")
     exit_status = pip_install("cython")
     if exit_status:
         log.error("Could not install pre-requisite: cython")
         return False
     log.info("Downloading classy...")
     success = download_github_release(os.path.join(path, "code"),
                                       cls._classy_repo_name,
                                       cls._classy_repo_version,
                                       directory=cls.__name__,
                                       no_progress_bars=no_progress_bars,
                                       logger=log)
     if not success:
         log.error("Could not download classy.")
         return False
     # Compilation
     # gcc check after downloading, in case the user wants to change the compiler by
     # hand in the Makefile
     classy_path = cls.get_path(path)
     if not check_gcc_version(cls._classy_min_gcc_version,
                              error_returns=False):
         log.error(
             "Your gcc version is too low! CLASS would probably compile, "
             "but it would leak memory when running a chain. Please use a "
             "gcc version newer than %s. You can still compile CLASS by hand, "
             "maybe changing the compiler in the Makefile. CLASS has been "
             "downloaded into %r", cls._classy_min_gcc_version, classy_path)
         return False
     log.info("Compiling classy...")
     from subprocess import Popen, PIPE
     env = deepcopy(os.environ)
     env.update({"PYTHON": sys.executable})
     process_make = Popen(["make"],
                          cwd=classy_path,
                          stdout=PIPE,
                          stderr=PIPE,
                          env=env)
     out, err = process_make.communicate()
     if process_make.returncode:
         log.info(out.decode("utf-8"))
         log.info(err.decode("utf-8"))
         log.error("Compilation failed!")
         return False
     return True
Example #20
 def is_installed(cls, **kwargs):
     if not kwargs.get("code", True):
         return True
     log = get_logger(cls.__name__)
     check = kwargs.get("check", True)
     func = log.info if check else log.error
     path = kwargs["path"]
     if path is not None and path.lower() == "global":
         path = None
     if path and not kwargs.get("allow_global"):
         log.info("Importing *local* CLASS from '%s'.", path)
         assert path is not None
         if not os.path.exists(path):
             func("The given folder does not exist: '%s'", path)
             return False
         classy_build_path = cls.get_import_path(path)
         if not classy_build_path:
             return False
     elif not path:
         log.info("Importing *global* CLASS.")
         classy_build_path = None
     else:
         log.info(
             "Importing *auto-installed* CLASS (but defaulting to *global*)."
         )
         classy_build_path = cls.get_import_path(path)
     try:
         return load_module('classy',
                            path=classy_build_path,
                            min_version=cls._classy_repo_version)
     except ImportError:
         if path is not None and path.lower() != "global":
             func(
                 "Couldn't find the CLASS python interface at '%s'. "
                 "Are you sure it has been installed there?", path)
         elif not check:
             log.error(
                 "Could not import global CLASS installation. "
                 "Specify a Cobaya or CLASS installation path, "
                 "or install the CLASS Python interface globally with "
                 "'cd /path/to/class/python/ ; python setup.py install'")
         return False
     except VersionCheckError as e:
         log.error(str(e))
         return False
Example #21
 def get_import_path(cls, path):
     log = get_logger(cls.__name__)
     poly_build_path = os.path.join(path, "build")
     if not os.path.isdir(poly_build_path):
         log.error(
             "Either PolyChord is not in the given folder, "
             "'%s', or you have not compiled it.", path)
         return None
     py_version = "%d.%d" % (sys.version_info.major, sys.version_info.minor)
     try:
         post = next(d for d in os.listdir(poly_build_path)
                     if (d.startswith("lib.") and py_version in d))
     except StopIteration:
         log.error(
             "The PolyChord installation at '%s' has not been compiled for the "
             "current Python version.", path)
         return None
     return os.path.join(poly_build_path, post)
Example #22
def get_sampler(info_sampler: SamplersDict,
                model: Model,
                output: Optional[Output] = None,
                packages_path: Optional[str] = None) -> 'Sampler':
    assert isinstance(info_sampler, Mapping), (
        "The first argument must be a dictionary with the info needed for the sampler. "
        "If you were trying to pass the name of an input file instead, "
        "load it first with 'cobaya.input.load_input', "
        "or, if you were passing a yaml string, load it with 'cobaya.yaml.yaml_load'."
    )
    logger_sampler = get_logger(__name__)
    info_sampler = deepcopy_where_possible(info_sampler)
    if output is None:
        output = OutputDummy()
    # Check and update info
    check_sane_info_sampler(info_sampler)
    updated_info_sampler = update_info({"sampler": info_sampler
                                        })["sampler"]  # type: ignore
    if is_debug(logger_sampler):
        logger_sampler.debug(
            "Input info updated with defaults (dumped to YAML):\n%s",
            yaml_dump(updated_info_sampler))
    # Get sampler class & check resume/force compatibility
    sampler_name, sampler_class = get_sampler_name_and_class(
        updated_info_sampler, logger=logger_sampler)
    check_sampler_info((output.reload_updated_info(use_cache=True)
                        or {}).get("sampler"),
                       updated_info_sampler,
                       is_resuming=output.is_resuming())
    # Check if resumable run
    sampler_class.check_force_resume(output,
                                     info=updated_info_sampler[sampler_name])
    # Instantiate the sampler
    sampler_instance = sampler_class(updated_info_sampler[sampler_name],
                                     model,
                                     output,
                                     packages_path=packages_path)
    # If output, dump updated
    if output:
        to_dump = model.info()
        to_dump["sampler"] = {sampler_name: sampler_instance.info()}
        to_dump["output"] = os.path.join(output.folder, output.prefix)
        output.check_and_dump_info(None, to_dump, check_compatible=False)
    return sampler_instance
Example #23
def check_sampler_info(info_old: Optional[SamplersDict],
                       info_new: SamplersDict,
                       is_resuming=False):
    """
    Checks compatibility between the new sampler info and that of a pre-existing run.

    Done separately from `Output.check_compatible_and_dump` because there may be
    multiple samplers mentioned in an `updated.yaml` file, e.g. `MCMC` + `Minimize`.
    """
    logger_sampler = get_logger(__name__)
    if not info_old:
        return
    # TODO: restore this at some point: just append minimize info to the old one
    # There is old info, but the new one is Minimizer and the old one is not
    # if (len(info_old) == 1 and list(info_old) != ["minimize"] and
    #      list(info_new) == ["minimize"]):
    #     # In-place append of old+new --> new
    #     aux = info_new.pop("minimize")
    #     info_new.update(info_old)
    #     info_new.update({"minimize": aux})
    #     info_old = {}
    #     keep_old = {}
    if list(info_old) != list(info_new) and list(info_new) == ["minimize"]:
        return
    if list(info_old) == list(info_new):
        # Restore some selected old values for some classes
        keep_old = get_preferred_old_values({"sampler": info_old})
        info_new = recursive_update(info_new, keep_old.get("sampler", {}))
    if not is_equal_info({"sampler": info_old}, {"sampler": info_new},
                         strict=False):
        if is_resuming:
            raise LoggedError(
                logger_sampler,
                "Old and new Sampler information not compatible! "
                "Resuming not possible!")
        else:
            raise LoggedError(
                logger_sampler,
                "Found old Sampler information which is not compatible "
                "with the new one. Delete the previous output manually, "
                "or automatically with either "
                "'-f', '--force', 'force: True'")
Example #24
 def __init__(self, *args, **kwargs):
     # Ensure check for install and version errors
     # (e.g. may inherit from a class that inherits from this one, and not have them)
     if self.install_options:
         name = self.get_qualified_class_name()
         logger = get_logger(name)
         packages_path = kwargs.get(
             "packages_path") or resolve_packages_path()
         old = False
         try:
             installed = self.is_installed(path=packages_path)
         except Exception as excpt:  # catches VersionCheckError and unexpected ones
             installed = False
             old = isinstance(excpt, VersionCheckError)
             logger.error(f"{type(excpt).__name__}: {excpt}")
         if not installed:
             not_or_old = ("is not up to date"
                           if old else "has not been correctly installed")
             raise ComponentNotInstalledError(logger, (
                 f"The data for this likelihood {not_or_old}. To install it, "
                 f"run `cobaya-install {name}{' --upgrade' if old else ''}`"
             ))
     super().__init__(*args, **kwargs)
Example #25
 def install(cls, path=None, code=True, no_progress_bars=False, **_kwargs):
     log = get_logger(cls.__name__)
     if not code:
         log.info("Code not requested. Nothing to do.")
         return True
     log.info("Downloading camb...")
     success = download_github_release(os.path.join(path, "code"),
                                       cls._camb_repo_name,
                                       cls._camb_repo_version,
                                       no_progress_bars=no_progress_bars,
                                       logger=log)
     if not success:
         log.error("Could not download camb.")
         return False
     camb_path = cls.get_path(path)
     log.info("Compiling camb...")
     from subprocess import Popen, PIPE
     process_make = Popen([sys.executable, "setup.py", "build_cluster"],
                          cwd=camb_path,
                          stdout=PIPE,
                          stderr=PIPE)
     out, err = process_make.communicate()
     if process_make.returncode:
         log.info(out.decode())
         log.info(err.decode())
         gcc_check = check_gcc_version(cls._camb_min_gcc_version,
                                       error_returns=False)
         if not gcc_check:
             cause = (
                 " Possible cause: it looks like `gcc` does not have the correct "
                 "version number (CAMB requires %s); and `ifort` is also "
                 "probably not available." % cls._camb_min_gcc_version)
         else:
             cause = ""
         log.error("Compilation failed!" + cause)
         return False
     return True
Example #26
def is_installed_clik(path, allow_global=False, check=True):
    log = get_logger("clik")
    func = log.info if check else log.error
    if path is not None and path.lower() == "global":
        path = None
    clik_path = None
    if isinstance(path, str) and path.lower() != "global":
        try:
            clik_path = os.path.join(get_clik_source_folder(path),
                                     'lib/python/site-packages')
        except FileNotFoundError:
            func("The given folder does not exist: '%s'", clik_path or path)
            return False
    if path and not allow_global:
        log.info("Importing *local* clik from %s ", path)
    elif not path:
        log.info("Importing *global* clik.")
    else:
        log.info(
            "Importing *auto-installed* clik (but defaulting to *global*).")
    try:
        return load_module("clik", path=clik_path)
    except ImportError:
        if path is not None and path.lower() != "global":
            func(
                "Couldn't find the clik python interface at '%s'. "
                "Are you sure it has been installed and compiled there?", path)
        elif not check:
            log.error("Could not import global clik installation. "
                      "Specify a Cobaya or clik installation path, "
                      "or install the clik Python interface globally.")
        return False
    except Exception as excpt:
        log.error(
            "Error when trying to import clik from %s [%s]. Error message: [%s].",
            path, clik_path, str(excpt))
        return False
Example #27
def post(info_or_yaml_or_file: Union[InputDict, str, os.PathLike],
         sample: Union[SampleCollection, List[SampleCollection], None] = None
         ) -> PostTuple:
    info = load_input_dict(info_or_yaml_or_file)
    logger_setup(info.get("debug"), info.get("debug_file"))
    log = get_logger(__name__)
    # MARKED FOR DEPRECATION IN v3.0
    if info.get("modules"):
        raise LoggedError(log, "The input field 'modules' has been deprecated."
                               "Please use instead %r", packages_path_input)
    # END OF DEPRECATION BLOCK
    info_post: PostDict = info.get("post") or {}
    if not info_post:
        raise LoggedError(log, "No 'post' block given. Nothing to do!")
    if mpi.is_main_process() and info.get("resume"):
        log.warning("Resuming not implemented for post-processing. Re-starting.")
    if not info.get("output") and info_post.get("output") \
            and not info.get("params"):
        raise LoggedError(log, "The input dictionary must have be a full option "
                               "dictionary, or have an existing 'output' root to load "
                               "previous settings from ('output' to read from is in the "
                               "main block not under 'post'). ")
    # 1. Load existing sample
    output_in = get_output(prefix=info.get("output"))
    if output_in:
        info_in = output_in.load_updated_info() or update_info(info)
    else:
        info_in = update_info(info)
    params_in: ExpandedParamsDict = info_in["params"]  # type: ignore
    dummy_model_in = DummyModel(params_in, info_in.get("likelihood", {}),
                                info_in.get("prior"))

    in_collections = []
    thin = info_post.get("thin", 1)
    skip = info_post.get("skip", 0)
    if info.get('thin') is not None or info.get('skip') is not None:  # type: ignore
        raise LoggedError(log, "'thin' and 'skip' should be "
                               "parameters of the 'post' block")

    if sample:
        # If MPI, assume for each MPI process post is passed in the list of
        # collections that should be processed by that process
        # (e.g. single chain output from sampler)
        if isinstance(sample, SampleCollection):
            in_collections = [sample]
        else:
            in_collections = sample
        for i, collection in enumerate(in_collections):
            if skip:
                if 0 < skip < 1:
                    skip = int(round(skip * len(collection)))
                collection = collection.filtered_copy(slice(skip, None))
            if thin != 1:
                collection = collection.thin_samples(thin)
            in_collections[i] = collection
    elif output_in:
        files = output_in.find_collections()
        numbered = files
        if not numbered:
            # look for un-numbered output files
            files = output_in.find_collections(name=False)
        if files:
            if mpi.size() > len(files):
                raise LoggedError(log, "Number of MPI processes (%s) is larger than "
                                       "the number of sample files (%s)",
                                  mpi.size(), len(files))
            for num in range(mpi.rank(), len(files), mpi.size()):
                in_collections += [SampleCollection(
                    dummy_model_in, output_in,
                    onload_thin=thin, onload_skip=skip, load=True, file_name=files[num],
                    name=str(num + 1) if numbered else "")]
        else:
            raise LoggedError(log, "No samples found for the input model with prefix %s",
                              os.path.join(output_in.folder, output_in.prefix))

    else:
        raise LoggedError(log, "No output from where to load from, "
                               "nor input collections given.")
    if any(len(c) <= 1 for c in in_collections):
        raise LoggedError(
            log, "Not enough samples for post-processing. Try using a larger sample, "
                 "or skipping or thinning less.")
    mpi.sync_processes()
    log.info("Will process %d sample points.", sum(len(c) for c in in_collections))

    # 2. Compare old and new info: determine what to do
    add = info_post.get("add") or {}
    if "remove" in add:
        raise LoggedError(log, "remove block should be under 'post', not 'add'")
    remove = info_post.get("remove") or {}
    # Add a dummy 'one' likelihood, to absorb unused parameters
    if not add.get("likelihood"):
        add["likelihood"] = {}
    add["likelihood"]["one"] = None
    # Expand the "add" info, but don't add new default sampled parameters
    orig_params = set(add.get("params") or [])
    add = update_info(add, add_aggr_chi2=False)
    add_params: ExpandedParamsDict = add["params"]  # type: ignore
    for p in set(add_params) - orig_params:
        if p in params_in:
            add_params.pop(p)

    # 2.1 Adding/removing derived parameters and changes in priors of sampled parameters
    out_combined_params = deepcopy_where_possible(params_in)
    remove_params = list(str_to_list(remove.get("params")) or [])
    for p in remove_params:
        pinfo = params_in.get(p)
        if pinfo is None or not is_derived_param(pinfo):
            raise LoggedError(
                log,
                "You tried to remove parameter '%s', which is not a derived parameter. "
                "Only derived parameters can be removed during post-processing.", p)
        out_combined_params.pop(p)
    # Force recomputation of aggregated chi2
    for p in list(out_combined_params):
        if p.startswith(get_chi2_name("")):
            out_combined_params.pop(p)
    prior_recompute_1d = False
    for p, pinfo in add_params.items():
        pinfo_in = params_in.get(p)
        if is_sampled_param(pinfo):
            if not is_sampled_param(pinfo_in):
                # No added sampled parameters (de-marginalisation not implemented)
                if pinfo_in is None:
                    raise LoggedError(
                        log, "You added a new sampled parameter %r (maybe accidentally "
                             "by adding a new likelihood that depends on it). "
                             "Adding new sampled parameters is not possible. Try fixing "
                             "it to some value.", p)
                else:
                    raise LoggedError(
                        log,
                        "You tried to change the prior of parameter '%s', "
                        "but it was not a sampled parameter. "
                        "To change that prior, you need to define as an external one.", p)
            # recompute prior if potentially changed sampled parameter priors
            prior_recompute_1d = True
        elif is_derived_param(pinfo):
            if p in out_combined_params:
                raise LoggedError(
                    log, "You tried to add derived parameter '%s', which is already "
                         "present. To force its recomputation, 'remove' it too.", p)
        elif is_fixed_or_function_param(pinfo):
            # Only one possibility left "fixed" parameter that was not present before:
            # input of new likelihood, or just an argument for dynamical derived (dropped)
            if pinfo_in and p in params_in and pinfo["value"] != pinfo_in.get("value"):
                raise LoggedError(
                    log,
                    "You tried to add a fixed parameter '%s: %r' that was already present"
                    " but had a different value or was not fixed. This is not allowed. "
                    "The old info of the parameter was '%s: %r'",
                    p, dict(pinfo), p, dict(pinfo_in))
        elif not pinfo_in:  # OK as long as we have a known value for it
            raise LoggedError(log, "Parameter %s has no known value.", p)
        out_combined_params[p] = pinfo

    out_combined: InputDict = {"params": out_combined_params}  # type: ignore
    # Turn the rest of *derived* parameters into constants,
    # so that the likelihoods do not try to recompute them
    # But be careful to exclude *input* params that have a "derived: True" value
    # (which in "updated info" turns into "derived: 'lambda [x]: [x]'")
    # Don't assign to derived parameters to theories, only likelihoods, so they can be
    # recomputed if needed. If the theory does not need to be computed, it doesn't matter
    # if it is already assigned parameters in the usual way; likelihoods can get
    # the required derived parameters from the stored sample derived parameter inputs.
    out_params_with_computed = deepcopy_where_possible(out_combined_params)

    dropped_theory = set()
    for p, pinfo in out_params_with_computed.items():
        if (is_derived_param(pinfo) and "value" not in pinfo
                and p not in add_params):
            out_params_with_computed[p] = {"value": np.nan}
            dropped_theory.add(p)
    # 2.2 Manage adding/removing priors and likelihoods
    warn_remove = False
    kind: ModelBlock
    for kind in ("prior", "likelihood", "theory"):
        out_combined[kind] = deepcopy_where_possible(info_in.get(kind)) or {}
        for remove_item in str_to_list(remove.get(kind)) or []:
            try:
                out_combined[kind].pop(remove_item, None)
                if remove_item not in (add.get(kind) or []) and kind != "theory":
                    warn_remove = True
            except ValueError:
                raise LoggedError(
                    log, "Trying to remove %s '%s', but it is not present. "
                         "Existing ones: %r", kind, remove_item, list(out_combined[kind]))
        if kind != "theory" and kind in add:
            dups = set(add.get(kind) or []).intersection(out_combined[kind]) - {"one"}
            if dups:
                raise LoggedError(
                    log, "You have added %s '%s', which was already present. If you "
                         "want to force its recomputation, you must also 'remove' it.",
                    kind, dups)
            out_combined[kind].update(add[kind])

    if warn_remove and mpi.is_main_process():
        log.warning("You are removing a prior or likelihood pdf. "
                    "Notice that if the resulting posterior is much wider "
                    "than the original one, or displaced enough, "
                    "it is probably safer to explore it directly.")

    mlprior_names_add = minuslogprior_names(add.get("prior") or [])
    chi2_names_add = [get_chi2_name(name) for name in add["likelihood"] if
                      name != "one"]
    out_combined["likelihood"].pop("one", None)

    add_theory = add.get("theory")
    if add_theory:
        if len(add["likelihood"]) == 1 and not any(
                is_derived_param(pinfo) for pinfo in add_params.values()):
            log.warning("You are adding a theory, but this does not force recomputation "
                        "of any likelihood or derived parameters unless explicitly "
                        "removed+added.")
        # Inherit from the original chain (input|output_params, renames, etc)
        added_theory = add_theory.copy()
        for theory, theory_info in out_combined["theory"].items():
            if theory in list(added_theory):
                out_combined["theory"][theory] = \
                    recursive_update(theory_info, added_theory.pop(theory))
        out_combined["theory"].update(added_theory)

    # Prepare recomputation of aggregated chi2
    # (they need to be recomputed by hand, because auto-computation won't pick up
    #  old likelihoods for a given type)
    all_types = {like: str_to_list(opts.get("type") or [])
                 for like, opts in out_combined["likelihood"].items()}
    types = set(chain(*all_types.values()))
    inv_types = {t: [like for like, like_types in all_types.items() if t in like_types]
                 for t in sorted(types)}
    add_aggregated_chi2_params(out_combined_params, types)

    # 3. Create output collection
    # Use default prefix if it exists. If it does not, produce no output by default.
    # {post: {output: None}} suppresses output, and if it's a string, updates it.
    out_prefix = info_post.get("output", info.get("output"))
    if out_prefix:
        suffix = info_post.get("suffix")
        if not suffix:
            raise LoggedError(log, "You need to provide a '%s' for your output chains.",
                              "suffix")
        out_prefix += separator_files + "post" + separator_files + suffix
    output_out = get_output(prefix=out_prefix, force=info.get("force"))
    output_out.set_lock()

    if output_out and not output_out.force and output_out.find_collections():
        raise LoggedError(log, "Found existing post-processing output with prefix %r. "
                               "Delete it manually or re-run with `force: True` "
                               "(or `-f`, `--force` from the shell).", out_prefix)
    elif output_out and output_out.force and mpi.is_main_process():
        output_out.delete_infos()
        for _file in output_out.find_collections():
            output_out.delete_file_or_folder(_file)
    info_out = deepcopy_where_possible(info)
    info_post = info_post.copy()
    info_out["post"] = info_post
    # Updated with input info and extended (updated) add info
    info_out.update(info_in)  # type: ignore
    info_post["add"] = add

    dummy_model_out = DummyModel(out_combined_params, out_combined["likelihood"],
                                 info_prior=out_combined["prior"])
    out_func_parameterization = Parameterization(out_params_with_computed)

    # TODO: check allow_renames=False?
    model_add = Model(out_params_with_computed, add["likelihood"],
                      info_prior=add.get("prior"), info_theory=out_combined["theory"],
                      packages_path=(info_post.get(packages_path_input) or
                                     info.get(packages_path_input)),
                      allow_renames=False, post=True,
                      stop_at_error=info.get('stop_at_error', False),
                      skip_unused_theories=True, dropped_theory_params=dropped_theory)
    # Remove auxiliary "one" before dumping -- 'add' *is* info_out["post"]["add"]
    add["likelihood"].pop("one")
    out_collections = [SampleCollection(dummy_model_out, output_out, name=c.name,
                                        cache_size=OutputOptions.default_post_cache_size)
                       for c in in_collections]
    # TODO: should maybe add skip/thin to out_combined, so can tell post-processed?
    output_out.check_and_dump_info(info_out, out_combined, check_compatible=False)
    collection_in = in_collections[0]
    collection_out = out_collections[0]

    last_percent = None
    known_constants = dummy_model_out.parameterization.constant_params()
    known_constants.update(dummy_model_in.parameterization.constant_params())
    missing_params = dummy_model_in.parameterization.sampled_params().keys() - set(
        collection_in.columns)
    if missing_params:
        raise LoggedError(log, "Input samples do not contain expected sampled parameter "
                               "values: %s", missing_params)

    missing_priors = set(name for name in collection_out.minuslogprior_names if
                         name not in mlprior_names_add
                         and name not in collection_in.columns)
    if _minuslogprior_1d_name in missing_priors:
        prior_recompute_1d = True
    if prior_recompute_1d:
        missing_priors.discard(_minuslogprior_1d_name)
        mlprior_names_add.insert(0, _minuslogprior_1d_name)
    prior_regenerate: Optional[Prior]
    if missing_priors and "prior" in info_in:
        # in case there are input priors that are not stored in input samples
        # e.g. when postprocessing GetDist/CosmoMC-format chains
        in_names = minuslogprior_names(info_in["prior"])
        info_prior = {piname: inf for (piname, inf), in_name in
                      zip(info_in["prior"].items(), in_names) if
                      in_name in missing_priors}
        regenerated_prior_names = minuslogprior_names(info_prior)
        missing_priors.difference_update(regenerated_prior_names)
        prior_regenerate = Prior(dummy_model_in.parameterization, info_prior)
    else:
        prior_regenerate = None
        regenerated_prior_names = None
    if missing_priors:
        raise LoggedError(log, "Missing priors: %s", missing_priors)

    mpi.sync_processes()
    output_in.check_lock()

    # 4. Main loop! Loop over input samples and adjust as required.
    if mpi.is_main_process():
        log.info("Running post-processing...")
    difflogmax: Optional[float] = None
    to_do = sum(len(c) for c in in_collections)
    weights = []
    done = 0
    last_dump_time = time.time()
    for collection_in, collection_out in zip(in_collections, out_collections):
        importance_weights = []

        def set_difflogmax():
            nonlocal difflogmax
            difflog = (collection_in[OutPar.minuslogpost].to_numpy(
                dtype=np.float64)[:len(collection_out)]
                       - collection_out[OutPar.minuslogpost].to_numpy(dtype=np.float64))
            difflogmax = np.max(difflog)
            if abs(difflogmax) < 1:
                difflogmax = 0  # keep simple when e.g. very similar
            log.debug("difflogmax: %g", difflogmax)
            if mpi.more_than_one_process():
                difflogmax = max(mpi.allgather(difflogmax))
            if mpi.is_main_process():
                log.debug("Set difflogmax: %g", difflogmax)
            _weights = np.exp(difflog - difflogmax)
            importance_weights.extend(_weights)
            collection_out.reweight(_weights)

        for i, point in collection_in.data.iterrows():
            all_params = point.to_dict()
            for p in remove_params:
                all_params.pop(p, None)
            log.debug("Point: %r", point)
            sampled = np.array([all_params[param] for param in
                                dummy_model_in.parameterization.sampled_params()])
            all_params = out_func_parameterization.to_input(all_params).copy()

            # Add/remove priors
            if prior_recompute_1d:
                priors_add = [model_add.prior.logps_internal(sampled)]
                if priors_add[0] == -np.inf:
                    continue
            else:
                priors_add = []
            if model_add.prior.external:
                priors_add.extend(model_add.prior.logps_external(all_params))

            logpriors_add = dict(zip(mlprior_names_add, priors_add))
            logpriors_new = [logpriors_add.get(name, - point.get(name, 0))
                             for name in collection_out.minuslogprior_names]
            if prior_regenerate:
                regenerated = dict(zip(regenerated_prior_names,
                                       prior_regenerate.logps_external(all_params)))
                for _i, name in enumerate(collection_out.minuslogprior_names):
                    if name in regenerated_prior_names:
                        logpriors_new[_i] = regenerated[name]

            if is_debug(log):
                log.debug("New set of priors: %r",
                          dict(zip(dummy_model_out.prior, logpriors_new)))
            if -np.inf in logpriors_new:
                continue
            # Add/remove likelihoods and/or (re-)calculate derived parameters
            loglikes_add, output_derived = model_add._loglikes_input_params(
                all_params, return_output_params=True)
            loglikes_add = dict(zip(chi2_names_add, loglikes_add))
            output_derived = dict(zip(model_add.output_params, output_derived))
            loglikes_new = [loglikes_add.get(name, -0.5 * point.get(name, 0))
                            for name in collection_out.chi2_names]
            if is_debug(log):
                log.debug("New set of likelihoods: %r",
                          dict(zip(dummy_model_out.likelihood, loglikes_new)))
                if output_derived:
                    log.debug("New set of derived parameters: %r", output_derived)
            if -np.inf in loglikes_new:
                continue
            all_params.update(output_derived)

            all_params.update(out_func_parameterization.to_derived(all_params))
            derived = {param: all_params.get(param) for param in
                       dummy_model_out.parameterization.derived_params()}
            # Recompute the aggregated chi2 by hand: for each likelihood "type",
            # sum -2*loglike over the likelihoods belonging to that type
            for type_, likes in inv_types.items():
                derived[get_chi2_name(type_)] = sum(
                    -2 * lvalue for lname, lvalue
                    in zip(collection_out.chi2_names, loglikes_new)
                    if undo_chi2_name(lname) in likes)
            if is_debug(log):
                log.debug("New derived parameters: %r",
                          {p: derived[p]
                           for p in dummy_model_out.parameterization.derived_params()
                           if p in add["params"]})
            # Save to the collection (keep old weight for now)
            weight = point.get(OutPar.weight)
            mpi.check_errors()
            if difflogmax is None and i > OutputOptions.reweight_after and \
                    time.time() - last_dump_time > OutputOptions.output_inteveral_s / 2:
                set_difflogmax()
                collection_out.out_update()

            if difflogmax is not None:
                logpost_new = sum(logpriors_new) + sum(loglikes_new)
                importance_weight = np.exp(logpost_new + point.get(OutPar.minuslogpost)
                                           - difflogmax)
                weight = weight * importance_weight
                importance_weights.append(importance_weight)
                if time.time() - last_dump_time > OutputOptions.output_inteveral_s:
                    collection_out.out_update()
                    last_dump_time = time.time()

            if weight > 0:
                collection_out.add(sampled, derived=derived.values(), weight=weight,
                                   logpriors=logpriors_new, loglikes=loglikes_new)

            # Display progress
            percent = int(np.round((i + done) / to_do * 100))
            if percent != last_percent and not percent % 5:
                last_percent = percent
                progress_bar(log, percent, " (%d/%d)" % (i + done, to_do))

        if difflogmax is None:
            set_difflogmax()
        if collection_out.data.last_valid_index() is None:
            raise LoggedError(
                log, "No elements in the final sample. Possible causes: "
                     "added a prior or likelihood valued zero over the full sampled "
                     "domain, or the computation of the theory failed everywhere, etc.")
        collection_out.out_update()
        weights.append(np.array(importance_weights))
        done += len(collection_in)

    assert difflogmax is not None
    points = 0
    tot_weight = 0
    min_weight = np.inf
    max_weight = -np.inf
    max_output_weight = -np.inf
    sum_w2 = 0
    points_removed = 0
    for collection_in, collection_out, importance_weights in zip(in_collections,
                                                                 out_collections,
                                                                 weights):
        output_weights = collection_out[OutPar.weight]
        points += len(collection_out)
        tot_weight += np.sum(output_weights)
        points_removed += len(importance_weights) - len(output_weights)
        min_weight = min(min_weight, np.min(importance_weights))
        max_weight = max(max_weight, np.max(importance_weights))
        max_output_weight = max(max_output_weight, np.max(output_weights))
        sum_w2 += np.dot(output_weights, output_weights)

    (tot_weights, min_weights, max_weights, max_output_weights, sum_w2s, points_s,
     points_removed_s) = mpi.zip_gather(
        [tot_weight, min_weight, max_weight, max_output_weight, sum_w2,
         points, points_removed])

    if mpi.is_main_process():
        output_out.clear_lock()
        log.info("Finished! Final number of distinct sample points: %s", sum(points_s))
        log.info("Importance weight range: %.4g -- %.4g",
                 min(min_weights), max(max_weights))
        if sum(points_removed_s):
            log.info("Points deleted due to zero weight: %s", sum(points_removed_s))
        log.info("Effective number of single samples if independent (sum w)/max(w): %s",
                 int(sum(tot_weights) / max(max_output_weights)))
        log.info(
            "Effective number of weighted samples if independent (sum w)^2/sum(w^2): "
            "%s", int(sum(tot_weights) ** 2 / sum(sum_w2s)))
    products: PostResultDict = {"sample": value_or_list(out_collections),
                                "stats": {'min_importance_weight': (min(min_weights) /
                                                                    max(max_weights)),
                                          'points_removed': sum(points_removed_s),
                                          'tot_weight': sum(tot_weights),
                                          'max_weight': max(max_output_weights),
                                          'sum_w2': sum(sum_w2s),
                                          'points': sum(points_s)},
                                "logpost_weight_offset": difflogmax,
                                "weights": value_or_list(weights)}
    return PostTuple(info=out_combined, products=products)
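
The reweighting above is plain importance sampling with a stability offset: each
point's weight gets multiplied by exp(logp_new - logp_old), and since the collections
store -log(post), that exponent is minuslogpost_old - minuslogpost_new; the maximum
(difflogmax) is subtracted before exponentiating so the exponentials stay bounded.
A minimal self-contained sketch of just that step (the function and array names here
are illustrative, not Cobaya API):

import numpy as np

def importance_reweight(weights, minuslogpost_old, minuslogpost_new):
    # difflog = logp_new - logp_old, expressed via the stored -log(post) columns
    difflog = minuslogpost_old - minuslogpost_new
    # Subtracting the maximum keeps every exp() <= 1; the constant offset cancels
    # when the weights are normalized
    return weights * np.exp(difflog - difflog.max())

# Toy usage: three points whose posterior barely changed
w_new = importance_reweight(np.ones(3),
                            np.array([10.0, 11.0, 12.0]),
                            np.array([10.5, 10.9, 12.3]))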
Exemplo n.º 28
0
from itertools import permutations
from typing import Mapping, Sequence, Any, List, TypeVar, Optional, Union, \
    Iterable, Set, Dict
from types import ModuleType
from inspect import cleandoc, getfullargspec
from ast import parse
import traceback

# Local
from cobaya.conventions import cobaya_package, subfolders, kinds, \
    packages_path_config_file, packages_path_env, packages_path_arg, dump_sort_cosmetic
from cobaya.log import LoggedError, HasLogger, get_logger
from cobaya.typing import Kind

# Set up logger
log = get_logger(__name__)


def str_to_list(x) -> List:
    """
    Makes sure that the input is a list of strings (a bare string is wrapped in a list).
    """
    return [x] if isinstance(x, str) else x
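
# Hypothetical usage (illustrative only, not part of the original module):
#   str_to_list("mcmc")         ->  ["mcmc"]
#   str_to_list(["mcmc", "de"]) ->  ["mcmc", "de"]   (non-strings pass through)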


def ensure_dict(iterable_or_dict):
    """
    For iterables, returns dict with elements as keys and null values.
    """
    if not isinstance(iterable_or_dict, Mapping):
        return dict.fromkeys(iterable_or_dict)
    # Mappings are passed through unchanged
    return iterable_or_dict
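
A quick sketch of how ensure_dict behaves under the pass-through reading above
(hypothetical calls, not from the original source):

ensure_dict(["mcmc", "polychord"])          # -> {"mcmc": None, "polychord": None}
ensure_dict({"mcmc": {"max_samples": 10}})  # -> unchanged mapping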
Exemplo n.º 29
0
def run_script(args=None):
    warn_deprecation()
    import argparse
    parser = argparse.ArgumentParser(prog="cobaya run",
                                     description="Cobaya's run script.")
    parser.add_argument("input_file",
                        action="store",
                        metavar="input_file.yaml",
                        help="An input file to run.")
    parser.add_argument("-" + packages_path_arg[0],
                        "--" + packages_path_arg_posix,
                        action="store",
                        metavar="/packages/path",
                        default=None,
                        help="Path where external packages were installed.")
    # MARKED FOR DEPRECATION IN v3.0
    modules = "modules"
    parser.add_argument("-" + modules[0],
                        "--" + modules,
                        action="store",
                        required=False,
                        metavar="/packages/path",
                        default=None,
                        help="To be deprecated! "
                        "Alias for %s, which should be used instead." %
                        packages_path_arg_posix)
    # END OF DEPRECATION BLOCK -- CONTINUES BELOW!
    parser.add_argument("-" + "o",
                        "--" + "output",
                        action="store",
                        metavar="/some/path",
                        default=None,
                        help="Path and prefix for the text output.")
    parser.add_argument("-" + "d",
                        "--" + "debug",
                        action="store_true",
                        help="Produce verbose debug output.")
    continuation = parser.add_mutually_exclusive_group(required=False)
    continuation.add_argument(
        "-" + "r",
        "--" + "resume",
        action="store_true",
        help="Resume an existing chain if it has similar info "
        "(fails otherwise).")
    continuation.add_argument("-" + "f",
                              "--" + "force",
                              action="store_true",
                              help="Overwrites previous output, if it exists "
                              "(use with care!)")
    parser.add_argument("--%s" % "test",
                        action="store_true",
                        help="Initialize model and sampler, and exit.")
    parser.add_argument("--version", action="version", version=get_version())
    parser.add_argument("--no-mpi",
                        action='store_true',
                        help="disable MPI when mpi4py installed but MPI does "
                        "not actually work")
    arguments = parser.parse_args(args)

    # MARKED FOR DEPRECATION IN v3.0
    if arguments.modules is not None:
        logger_setup()
        logger = get_logger("run")
        logger.warning(
            "*DEPRECATION*: -m/--modules will be deprecated in favor of "
            "-%s/--%s in the next version. Please, use that one instead.",
            packages_path_arg[0], packages_path_arg_posix)
        # BEHAVIOUR TO BE REPLACED BY ERROR:
        if getattr(arguments, packages_path_arg) is None:
            setattr(arguments, packages_path_arg, arguments.modules)
    del arguments.modules
    # END OF DEPRECATION BLOCK
    info = arguments.input_file
    del arguments.input_file
    run(info, **arguments.__dict__)
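
Since parse_args accepts an explicit argument list, run_script can also be invoked
programmatically; a hedged sketch, with a placeholder input file name:

run_script(["my_input.yaml", "--output", "chains/my_run", "--test", "--no-mpi"])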
Exemplo n.º 30
0
def run(
    info_or_yaml_or_file: Union[InputDict, str, os.PathLike],
    packages_path: Optional[str] = None,
    output: Union[str, LiteralFalse, None] = None,
    debug: Union[bool, int, None] = None,
    stop_at_error: Optional[bool] = None,
    resume: bool = False,
    force: bool = False,
    no_mpi: bool = False,
    test: bool = False,
    override: Optional[InputDict] = None,
) -> Union[InfoSamplerTuple, PostTuple]:
    """
    Run from an input dictionary, file name or yaml string, with optional arguments
    to override settings in the input as needed.

    :param info_or_yaml_or_file: input options dictionary, yaml file, or yaml text
    :param packages_path: path where external packages were installed
    :param output: path name prefix for output files, or False for no file output
    :param debug: true for verbose debug output, or a specific logging level
    :param stop_at_error: stop if an error is raised
    :param resume: continue an existing run
    :param force: overwrite existing output if it exists
    :param no_mpi: run without MPI
    :param test: only test initialization rather than actually running
    :param override: option dictionary to merge into the input one, overriding settings
       (but with lower precedence than the explicit keyword arguments)
    :return: (updated_info, sampler) tuple of options dictionary and Sampler instance,
              or (updated_info, results) if using "post" post-processing
    """

    # This function reproduces the model-->output-->sampler pipeline one would follow
    # when instantiating by hand, but alters the order so as to perform checks and dump
    # info as early as possible, e.g. to check whether resuming is possible or whether
    # `force` is needed.
    if no_mpi or test:
        mpi.set_mpi_disabled()

    with mpi.ProcessState("run"):
        info: InputDict = load_info_overrides(info_or_yaml_or_file, debug,
                                              stop_at_error, packages_path,
                                              override)

        if test:
            info["test"] = True
        # If resume or force are given as cmd-line args, override the input file values
        if resume or force:
            if resume and force:
                raise ValueError("'rename' and 'force' are exclusive options")
            info["resume"] = bool(resume)
            info["force"] = bool(force)
        if info.get("post"):
            if isinstance(output, str) or output is False:
                info["post"]["output"] = output or None
            return post(info)

        if isinstance(output, str) or output is False:
            info["output"] = output or None
        logger_setup(info.get("debug"), info.get("debug_file"))
        logger_run = get_logger(run.__name__)
        # MARKED FOR DEPRECATION IN v3.0
        # BEHAVIOUR TO BE REPLACED BY ERROR:
        check_deprecated_modules_path(info)
        # END OF DEPRECATION BLOCK
        # 1. Prepare the output driver, if requested (i.e. if an output prefix is given)
        # GetDist needs to know the original sampler, so don't overwrite it when minimizing
        try:
            which_sampler = list(info["sampler"])[0]
        except (KeyError, TypeError):
            raise LoggedError(
                logger_run,
                "You need to specify a sampler using the 'sampler' key "
                "as e.g. `sampler: {mcmc: None}.`")
        infix = "minimize" if which_sampler == "minimize" else None
        with get_output(prefix=info.get("output"),
                        resume=info.get("resume"),
                        force=info.get("force"),
                        infix=infix) as out:
            # 2. Update the input info with the defaults for each component
            updated_info = update_info(info)
            if is_debug(logger_run):
                # Dump only if not doing output
                # (otherwise, the user can check the .updated file)
                if not out and mpi.is_main_process():
                    logger_run.info(
                        "Input info updated with defaults (dumped to YAML):\n%s",
                        yaml_dump(sort_cosmetic(updated_info)))
            # 3. If output requested, check compatibility if existing one, and dump.
            # 3.1 First: model only
            out.check_and_dump_info(info,
                                    updated_info,
                                    cache_old=True,
                                    ignore_blocks=["sampler"])
            # 3.2 Then sampler -- 1st get the last sampler mentioned in the updated.yaml
            # TODO: ideally, using Minimizer would *append* to the sampler block.
            #       Some code already in place, but not possible at the moment.
            try:
                last_sampler = list(updated_info["sampler"])[-1]
                last_sampler_info = {
                    last_sampler: updated_info["sampler"][last_sampler]
                }
            except (KeyError, TypeError):
                raise LoggedError(logger_run, "No sampler requested.")
            sampler_name, sampler_class = get_sampler_name_and_class(
                last_sampler_info)
            check_sampler_info((out.reload_updated_info(use_cache=True)
                                or {}).get("sampler"),
                               updated_info["sampler"],
                               is_resuming=out.is_resuming())
            # Dump again, now including sampler info
            out.check_and_dump_info(info, updated_info, check_compatible=False)
            # Check if resumable run
            sampler_class.check_force_resume(
                out, info=updated_info["sampler"][sampler_name])
            # 4. Initialize the posterior and the sampler
            with Model(updated_info["params"],
                       updated_info["likelihood"],
                       updated_info.get("prior"),
                       updated_info.get("theory"),
                       packages_path=info.get("packages_path"),
                       timing=updated_info.get("timing"),
                       allow_renames=False,
                       stop_at_error=info.get("stop_at_error",
                                              False)) as model:
                # Re-dump the updated info, now containing parameter routes and version
                updated_info = recursive_update(updated_info, model.info())
                out.check_and_dump_info(None,
                                        updated_info,
                                        check_compatible=False)
                sampler = sampler_class(
                    updated_info["sampler"][sampler_name],
                    model,
                    out,
                    name=sampler_name,
                    packages_path=info.get("packages_path"))
                # Re-dump updated info, now also containing updates from the sampler
                updated_info["sampler"][sampler_name] = \
                    recursive_update(updated_info["sampler"][sampler_name],
                                     sampler.info())
                out.check_and_dump_info(None,
                                        updated_info,
                                        check_compatible=False)
                mpi.sync_processes()
                if info.get("test", False):
                    logger_run.info(
                        "Test initialization successful! "
                        "You can probably run now without `--%s`.", "test")
                    return InfoSamplerTuple(updated_info, sampler)
                # Run the sampler
                sampler.run()

    return InfoSamplerTuple(updated_info, sampler)
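
For completeness, a minimal programmatic call of run(), loosely following Cobaya's
quickstart example; the likelihood settings and parameter names below are
illustrative assumptions, not taken from this snippet:

from cobaya.run import run

info = {
    "likelihood": {"gaussian_mixture": {"means": [0.2, 0.0],
                                        "covs": [[0.1, 0.05], [0.05, 0.2]]}},
    "params": {
        "a": {"prior": {"min": -0.5, "max": 3}, "proposal": 0.3},
        "b": {"prior": {"dist": "norm", "loc": 0, "scale": 1},
              "ref": 0, "proposal": 0.5}},
    "sampler": {"mcmc": {"max_samples": 1000}}}

updated_info, sampler = run(info, output=False)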