Ejemplo n.º 1
0
def test_mcmc_sync() -> None:
    """
    Check MPI synchronization of the MCMC sampler and of ``mpi.ProcessState``.

    Three scenarios are exercised, in order:

    1. End synchronization: rank 1 asks for 200 samples and every other rank
       for 600, so processes leave their sampling loops at different times;
       each process must still end up with exactly its own ``max_samples``.
    2. Error synchronization: rank 0 passes an invalid ``max_samples`` and
       expects a ``TypeError``; every other rank (re-running with the settings
       left over from scenario 1) expects ``mpi.OtherProcessError``.
    3. One-process hang abort: inside a ``ProcessState`` with a 0.5 s timeout,
       rank 1 raises ``LoggedError`` while the other ranks sleep 0.6 s. Every
       rank accepts either ``LoggedError`` or ``mpi.OtherProcessError`` from
       that block, and rank 1 must have invoked the ``timeout_abort_proc``
       callback.

    NOTE(review): rank 1 is special-cased throughout, so this presumably has
    to be run under MPI with at least 2 processes — confirm.
    """
    info: InputDict = yaml_load(yaml)
    logger.info('Test end synchronization')

    if mpi.rank() == 1:
        max_samples = 200
    else:
        max_samples = 600
    # simulate asynchronous ending sampling loop
    info['sampler']['mcmc'] = {'max_samples': max_samples}

    updated_info, sampler = run(info)
    # Each process checks only the length of its own chain.
    assert len(sampler.products()["sample"]) == max_samples

    logger.info('Test error synchronization')
    if mpi.rank() == 0:
        info['sampler']['mcmc'] = {'max_samples': 'none'}  # 'none' not valid
        with NoLogging(logging.ERROR), pytest.raises(TypeError):
            run(info)
    else:
        # The error originates on rank 0; the remaining ranks are expected to
        # see it surface as OtherProcessError.
        with pytest.raises(mpi.OtherProcessError):
            run(info)

    logger.info('Test one-process hang abort')

    # Flipped by the timeout-abort callback below so the test can observe that
    # it was called (presumably in place of a real MPI abort — confirm).
    aborted = False

    def test_abort() -> None:
        """Record that ProcessState invoked its timeout-abort callback."""
        nonlocal aborted
        aborted = True

    # test error converted into MPI_ABORT after timeout
    # noinspection PyTypeChecker
    with pytest.raises(
        (LoggedError, mpi.OtherProcessError)), NoLogging(logging.ERROR):
        with mpi.ProcessState('test',
                              time_out_seconds=0.5,
                              timeout_abort_proc=test_abort):
            # Non-failing ranks sleep longer than the 0.5 s timeout, presumably
            # so that rank 1's error is not acknowledged in time.
            if mpi.rank() != 1:
                time.sleep(0.6)  # fake hang
            else:
                raise LoggedError(logger, 'Expected test error')
    # Only the rank that raised (rank 1) is expected to have hit the timeout.
    if mpi.rank() == 1:
        assert aborted
Ejemplo n.º 2
0
    def run(self):
        """
        Runs the sampler.

        Main sampling loop: repeatedly calls ``get_new_sample()`` until either
        ``self.max_samples`` accepted steps have been collected (as counted by
        ``self.n()``) or ``self.converged`` becomes true. The loop runs inside
        an ``mpi.ProcessState`` so that errors on other processes can be picked
        up through ``state.check_error()``.

        Each iteration may do the following:

        - write output at fixed wall-clock intervals, when ``output_every`` has
          a time unit;
        - invoke the user callback every ``callback_every`` accepted steps;
        - check readiness for the convergence check / proposal learning step,
          which is synchronized across processes when there is more than one.

        After the loop, the last batch of samples is written out. The
        accepted-step counts of all processes are then gathered, and their sum
        is logged.
        """
        self.mpi_info("Sampling!" + (
            " (NB: no accepted step will be saved until %d burn-in samples " %
            self.burn_in.value +
            "have been obtained)" if self.burn_in.value else ""))
        # Raw step counter: incremented once per loop iteration, i.e. once per
        # call to get_new_sample(), whether or not a new point was added.
        self.n_steps_raw = 0
        # Timestamp (seconds) of the last timed output; starting at 0 means the
        # first time-based check always triggers an output.
        last_output: float = 0
        # Accepted-step count last observed via self.n().
        last_n = self.n()
        # Period, in units of the current point's weight, of the error checks
        # done while no new point is added; it grows up to 10 below.
        state_check_every = 1
        with mpi.ProcessState(self) as state:
            while last_n < self.max_samples and not self.converged:
                self.get_new_sample()
                self.n_steps_raw += 1
                if self.output_every.unit:
                    # if output_every in sec, print some info
                    # and dump at fixed time intervals
                    now = datetime.datetime.now()
                    now_sec = now.timestamp()
                    if now_sec >= last_output + self.output_every.value:
                        self.do_output(now)
                        last_output = now_sec
                        # Also check for errors on other processes whenever a
                        # timed output happens.
                        state.check_error()
                if self.current_point.weight == 1:
                    # have added new point
                    # Callback function
                    n = self.n()
                    # NOTE(review): n may presumably stay unchanged while
                    # burn-in samples are being discarded — confirm.
                    if n != last_n:
                        # and actually added
                        last_n = n
                        # NOTE(review): the `weight == 1` test below appears to
                        # repeat the enclosing condition; it should always hold
                        # here unless self.n() modifies current_point.
                        if (self.callback_function
                                and not (max(n, 1) % self.callback_every.value)
                                and self.current_point.weight == 1):
                            self.callback_function_callable(self)
                            self.last_point_callback = len(self.collection)

                        if more_than_one_process():
                            # Checking convergence and (optionally) learning
                            # the covmat of the proposal
                            # This process flags itself READY once check_ready()
                            # holds. The convergence check / learning step runs
                            # only when state.all_ready() reports that all
                            # processes are ready.
                            if self.check_ready() and state.set(
                                    mpi.State.READY):
                                self.log.info(self._msg_ready +
                                              " (waiting for the rest...)")
                            if state.all_ready():
                                self.mpi_info("All chains are r%s",
                                              self._msg_ready[1:])
                                self.check_convergence_and_learn_proposal()
                                self.i_learn += 1
                        else:
                            # Single process: no synchronization needed.
                            if self.check_ready():
                                self.log.debug(self._msg_ready)
                                self.check_convergence_and_learn_proposal()
                                self.i_learn += 1
                elif self.current_point.weight % state_check_every == 0:
                    # No new point this iteration (presumably a rejected step
                    # that raised the current point's weight). Check for errors
                    # on other processes every `state_check_every` weight
                    # increments, widening the period after each check.
                    state.check_error()
                    # more frequent checks near beginning
                    state_check_every = min(10, state_check_every + 1)

            # NOTE(review): this uses `==`. If the chain already held more than
            # max_samples accepted steps before the loop started (possibly when
            # resuming — confirm), the message below is not logged.
            if last_n == self.max_samples:
                self.log.info(
                    "Reached maximum number of accepted steps allowed (%s). "
                    "Stopping.", self.max_samples)

            # Write the last batch of samples ( < output_every (not in sec))
            self.collection.out_update()

        # Collect each process's accepted-step count and log their total.
        ns = mpi.gather(self.n())
        self.mpi_info("Sampling complete after %d accepted steps.", sum(ns))
Ejemplo n.º 3
0
def run(
    info_or_yaml_or_file: Union[InputDict, str, os.PathLike],
    packages_path: Optional[str] = None,
    output: Union[str, LiteralFalse, None] = None,
    debug: Union[bool, int, None] = None,
    stop_at_error: Optional[bool] = None,
    resume: bool = False,
    force: bool = False,
    no_mpi: bool = False,
    test: bool = False,
    override: Optional[InputDict] = None,
) -> Union[InfoSamplerTuple, PostTuple]:
    """
    Run from an input dictionary, file name or yaml string, with optional arguments
    to override settings in the input as needed.

    :param info_or_yaml_or_file: input options dictionary, yaml file, or yaml text
    :param packages_path: path where external packages were installed
    :param output: path name prefix for output files, or False for no file output
    :param debug: true for verbose debug output, or a specific logging level
    :param stop_at_error: stop if an error is raised
    :param resume: continue an existing run
    :param force: overwrite existing output if it exists
    :param no_mpi: run without MPI
    :param test: only test initialization rather than actually running
    :param override: option dictionary to merge into the input one, overriding settings
       (but with lower precedence than the explicit keyword arguments)
    :return: (updated_info, sampler) tuple of options dictionary and Sampler instance,
              or (updated_info, results) if using "post" post-processing
    :raises ValueError: if both ``resume`` and ``force`` are requested
    :raises LoggedError: if the input does not specify a (non-empty) sampler block
    """

    # This function reproduces the model-->output-->sampler pipeline one would follow
    # when instantiating by hand, but alters the order to perform checks and dump info
    # as early as possible, e.g. to check if resuming possible or `force` needed.
    if no_mpi or test:
        mpi.set_mpi_disabled()

    with mpi.ProcessState("run"):
        info: InputDict = load_info_overrides(info_or_yaml_or_file, debug,
                                              stop_at_error, packages_path,
                                              override)

        if test:
            info["test"] = True
        # If any of resume|force given as cmd args, ignore those in the input file
        if resume or force:
            if resume and force:
                raise ValueError("'resume' and 'force' are exclusive options")
            info["resume"] = bool(resume)
            info["force"] = bool(force)
        if info.get("post"):
            if isinstance(output, str) or output is False:
                info["post"]["output"] = output or None
            return post(info)

        if isinstance(output, str) or output is False:
            info["output"] = output or None
        logger_setup(info.get("debug"), info.get("debug_file"))
        logger_run = get_logger(run.__name__)
        # MARKED FOR DEPRECATION IN v3.0
        # BEHAVIOUR TO BE REPLACED BY ERROR:
        check_deprecated_modules_path(info)
        # END OF DEPRECATION BLOCK
        # 1. Prepare output driver, if requested by defining an output_prefix
        # GetDist needs to know the original sampler, so don't overwrite if minimizer
        try:
            which_sampler = list(info["sampler"])[0]
        except (KeyError, TypeError, IndexError):
            # KeyError: no "sampler" key; TypeError: the block is None or not
            # iterable; IndexError: the block is present but empty.
            raise LoggedError(
                logger_run,
                "You need to specify a sampler using the 'sampler' key "
                "as e.g. `sampler: {mcmc: None}.`") from None
        infix = "minimize" if which_sampler == "minimize" else None
        with get_output(prefix=info.get("output"),
                        resume=info.get("resume"),
                        force=info.get("force"),
                        infix=infix) as out:
            # 2. Update the input info with the defaults for each component
            updated_info = update_info(info)
            if is_debug(logger_run):
                # Dump only if not doing output
                # (otherwise, the user can check the .updated file)
                if not out and mpi.is_main_process():
                    logger_run.info(
                        "Input info updated with defaults (dumped to YAML):\n%s",
                        yaml_dump(sort_cosmetic(updated_info)))
            # 3. If output requested, check compatibility if existing one, and dump.
            # 3.1 First: model only
            out.check_and_dump_info(info,
                                    updated_info,
                                    cache_old=True,
                                    ignore_blocks=["sampler"])
            # 3.2 Then sampler -- 1st get the last sampler mentioned in the updated.yaml
            # TODO: ideally, using Minimizer would *append* to the sampler block.
            #       Some code already in place, but not possible at the moment.
            try:
                last_sampler = list(updated_info["sampler"])[-1]
                last_sampler_info = {
                    last_sampler: updated_info["sampler"][last_sampler]
                }
            except (KeyError, TypeError, IndexError):
                # Same failure modes as above, for the updated info.
                raise LoggedError(logger_run, "No sampler requested.") from None
            sampler_name, sampler_class = get_sampler_name_and_class(
                last_sampler_info)
            check_sampler_info((out.reload_updated_info(use_cache=True)
                                or {}).get("sampler"),
                               updated_info["sampler"],
                               is_resuming=out.is_resuming())
            # Dump again, now including sampler info
            out.check_and_dump_info(info, updated_info, check_compatible=False)
            # Check if resumable run
            sampler_class.check_force_resume(
                out, info=updated_info["sampler"][sampler_name])
            # 4. Initialize the posterior and the sampler
            with Model(updated_info["params"],
                       updated_info["likelihood"],
                       updated_info.get("prior"),
                       updated_info.get("theory"),
                       packages_path=info.get("packages_path"),
                       timing=updated_info.get("timing"),
                       allow_renames=False,
                       stop_at_error=info.get("stop_at_error",
                                              False)) as model:
                # Re-dump the updated info, now containing parameter routes and version
                updated_info = recursive_update(updated_info, model.info())
                out.check_and_dump_info(None,
                                        updated_info,
                                        check_compatible=False)
                sampler = sampler_class(
                    updated_info["sampler"][sampler_name],
                    model,
                    out,
                    name=sampler_name,
                    packages_path=info.get("packages_path"))
                # Re-dump updated info, now also containing updates from the sampler
                updated_info["sampler"][sampler_name] = \
                    recursive_update(updated_info["sampler"][sampler_name],
                                     sampler.info())
                out.check_and_dump_info(None,
                                        updated_info,
                                        check_compatible=False)
                # Make sure every process has finished initialization before
                # either returning (test mode) or starting to sample.
                mpi.sync_processes()
                if info.get("test", False):
                    logger_run.info(
                        "Test initialization successful! "
                        "You can probably run now without `--%s`.", "test")
                    return InfoSamplerTuple(updated_info, sampler)
                # Run the sampler
                sampler.run()

    return InfoSamplerTuple(updated_info, sampler)